diff --git a/.github/ISSUE_TEMPLATE/miscellaneous.md b/.github/ISSUE_TEMPLATE/miscellaneous.md index 80ee92ad3..d77c625c3 100644 --- a/.github/ISSUE_TEMPLATE/miscellaneous.md +++ b/.github/ISSUE_TEMPLATE/miscellaneous.md @@ -10,8 +10,7 @@ assignees: '' For anything other than bug reports and feature requests (performance, refactoring, etc), just go ahead and file the issue. Please provide as many details as possible. -If you have a question or a support request, please open a new discussion on [GitHub Discussions](https://github.com/spring-projects-experimental/spring-ai-mcp/discussions) -or ask a question on [StackOverflow](https://stackoverflow.com/questions/tagged/spring-ai-mcp). +If you have a question or a support request, please open a new discussion on [GitHub Discussions](https://github.com/modelcontextprotocol/java-sdk/discussions) -Please do **not** create issues on the [Issue Tracker](https://github.com/spring-projects-experimental/spring-ai-mcp/issues) for questions or support requests. +Please do **not** create issues on the [Issue Tracker](https://github.com/modelcontextprotocol/java-sdk/issues) for questions or support requests. We would like to keep the issue tracker **exclusively** for bug reports and feature requests. diff --git a/.github/workflows/artifactory-milestone-release.yml b/.github/workflows/artifactory-milestone-release.yml deleted file mode 100644 index 98697f951..000000000 --- a/.github/workflows/artifactory-milestone-release.yml +++ /dev/null @@ -1,38 +0,0 @@ -name: Artifactory Milestone Release - -on: - workflow_dispatch: - inputs: - releaseVersion: - description: "Milestone release version" - required: true - -jobs: - build: - name: Release milestone to Artifactory - runs-on: ubuntu-latest - steps: - - name: Checkout source code - uses: actions/checkout@v4 - - - name: Set up JDK 17 - uses: actions/setup-java@v4 - with: - java-version: '17' - distribution: 'temurin' - cache: 'maven' - - - name: Capture release version - run: echo RELEASE_VERSION=${{ github.event.inputs.releaseVersion }} >> $GITHUB_ENV - - - name: Update release version - run: mvn versions:set -DgenerateBackupPoms=false -DnewVersion=$RELEASE_VERSION - - - name: Enforce release rules - run: mvn org.apache.maven.plugins:maven-enforcer-plugin:enforce -Drules=requireReleaseDeps - - - name: Build with Maven and deploy to Artifactory's milestone repository - env: - ARTIFACTORY_USERNAME: ${{ secrets.ARTIFACTORY_USERNAME }} - ARTIFACTORY_PASSWORD: ${{ secrets.ARTIFACTORY_PASSWORD }} - run: mvn -P artifactory-milestone -P javadoc -s settings.xml --batch-mode -Dmaven.test.skip=true deploy diff --git a/.github/workflows/artifactory-staging.yml b/.github/workflows/artifactory-staging.yml deleted file mode 100644 index 04493b2c4..000000000 --- a/.github/workflows/artifactory-staging.yml +++ /dev/null @@ -1,38 +0,0 @@ -name: Artifactory Staging - -on: - workflow_dispatch: - inputs: - releaseVersion: - description: "Release version" - required: true - -jobs: - build: - name: Stage release to Artifactory - runs-on: ubuntu-latest - steps: - - name: Checkout source code - uses: actions/checkout@v4 - - - name: Set up JDK 17 - uses: actions/setup-java@v4 - with: - java-version: '17' - distribution: 'temurin' - cache: 'maven' - - - name: Capture release version - run: echo RELEASE_VERSION=${{ github.event.inputs.releaseVersion }} >> $GITHUB_ENV - - - name: Update release version - run: mvn versions:set -DgenerateBackupPoms=false -DnewVersion=$RELEASE_VERSION - - - name: Enforce release rules - run: mvn org.apache.maven.plugins:maven-enforcer-plugin:enforce -Drules=requireReleaseDeps - - - name: Build with Maven and deploy to Artifactory staging repository - env: - ARTIFACTORY_USERNAME: ${{ secrets.ARTIFACTORY_USERNAME }} - ARTIFACTORY_PASSWORD: ${{ secrets.ARTIFACTORY_PASSWORD }} - run: mvn -P artifactory-staging -P javadoc -s settings.xml --batch-mode -Dmaven.test.skip=true deploy diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 000000000..7c73d9f38 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,22 @@ +name: CI + +on: + pull_request: {} + +jobs: + build: + name: Build branch + runs-on: ubuntu-latest + steps: + - name: Checkout source code + uses: actions/checkout@v4 + + - name: Set up JDK 17 + uses: actions/setup-java@v4 + with: + java-version: '17' + distribution: 'temurin' + cache: 'maven' + + - name: Build + run: mvn verify diff --git a/.github/workflows/continuous-integration.yml b/.github/workflows/continuous-integration.yml deleted file mode 100644 index e119a701f..000000000 --- a/.github/workflows/continuous-integration.yml +++ /dev/null @@ -1,65 +0,0 @@ -name: CI/CD build - -on: - push: - branches: [ "main" ] - -jobs: - build: - name: Build branch - runs-on: ubuntu-latest - steps: - - name: Checkout source code - uses: actions/checkout@v4 - - - name: Set up JDK 17 - uses: actions/setup-java@v4 - with: - java-version: '17' - distribution: 'temurin' - cache: 'maven' - - - name: Setup Node.js - uses: actions/setup-node@v4 - with: - node-version: '20' - - - name: Build with Maven and deploy to Artifactory - env: - ARTIFACTORY_USERNAME: ${{ secrets.ARTIFACTORY_USERNAME }} - ARTIFACTORY_PASSWORD: ${{ secrets.ARTIFACTORY_PASSWORD }} - run: | - mvn -s settings.xml -Pjavadoc --batch-mode --update-snapshots deploy - - - name: Generate Java docs - run: mvn -Pjavadoc -B javadoc:aggregate - - - name: Capture project version - run: echo PROJECT_VERSION=$(mvn help:evaluate -Dexpression=project.version --quiet -DforceStdout) >> $GITHUB_ENV - -# - name: Generate assembly -# working-directory: spring-ai-docs -# run: mvn assembly:single - - -# - name: Setup SSH key -# env: -# DOCS_SSH_KEY: ${{ secrets.DOCS_SSH_KEY }} -# DOCS_SSH_HOST_KEY: ${{ secrets.DOCS_SSH_HOST_KEY }} -# run: | -# mkdir "$HOME/.ssh" -# echo "$DOCS_SSH_KEY" > "$HOME/.ssh/key" -# chmod 600 "$HOME/.ssh/key" -# echo "$DOCS_SSH_HOST_KEY" > "$HOME/.ssh/known_hosts" -# -# - name: Deploy docs -# env: -# DOCS_HOST: ${{ secrets.DOCS_HOST }} -# DOCS_PATH: ${{ secrets.DOCS_PATH }} -# DOCS_USERNAME: ${{ secrets.DOCS_USERNAME }} -# working-directory: spring-ai-docs/target -# run: | -# unzip spring-ai-$PROJECT_VERSION-docs.zip -# ssh -i $HOME/.ssh/key $DOCS_USERNAME@$DOCS_HOST "cd $DOCS_PATH && mkdir -p $PROJECT_VERSION" -# scp -i $HOME/.ssh/key -r api $DOCS_USERNAME@$DOCS_HOST:$DOCS_PATH/$PROJECT_VERSION - diff --git a/.github/workflows/deploy-docs.yml b/.github/workflows/deploy-docs.yml deleted file mode 100644 index de69a219a..000000000 --- a/.github/workflows/deploy-docs.yml +++ /dev/null @@ -1,28 +0,0 @@ -name: Deploy Docs -on: - workflow_dispatch: - push: - branches: [main, '[0-9].[0-9].x' ] - tags: ['v[0-9].[0-9].[0-9]', 'v[0-9].[0-9].[0-9]-*'] -permissions: - actions: write -jobs: - build: - runs-on: ubuntu-latest - if: github.repository_owner == 'spring-projects-experimental' - steps: - - name: Checkout - uses: actions/checkout@v4 - with: - ref: docs-build - fetch-depth: 1 - - name: Dispatch (partial build) - if: github.ref_type == 'branch' - env: - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: gh workflow run deploy-docs.yml -r $(git rev-parse --abbrev-ref HEAD) -f build-refname=${{ github.ref_name }} - - name: Dispatch (full build) - if: github.ref_type == 'tag' - env: - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: gh workflow run deploy-docs.yml -r $(git rev-parse --abbrev-ref HEAD) \ No newline at end of file diff --git a/.github/workflows/maven-central-release.yml b/.github/workflows/maven-central-release.yml index d3de02640..c6c9d3ab6 100644 --- a/.github/workflows/maven-central-release.yml +++ b/.github/workflows/maven-central-release.yml @@ -1,62 +1,41 @@ -name: Maven Central Release +name: Release to Maven Central on: workflow_dispatch: - inputs: - releaseVersion: - description: "Release version" - required: true jobs: - build: + publish: runs-on: ubuntu-latest steps: + - uses: actions/checkout@v4 + + - name: Set up Java + uses: actions/setup-java@v4 + with: + java-version: '17' + distribution: 'temurin' + cache: 'maven' + server-id: central + server-username: MAVEN_USERNAME + server-password: MAVEN_PASSWORD + gpg-private-key: ${{ secrets.GPG_SECRET_KEY }} + gpg-passphrase: MAVEN_GPG_PASSPHRASE - - name: Capture release version - run: echo RELEASE_VERSION=${{ github.event.inputs.releaseVersion }} >> $GITHUB_ENV + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: '20' + + - name: Build and Test + run: mvn clean verify - - name: Prepare directory structure + - name: Publish to Maven Central run: | - mkdir -p nexus/org/springframework/experimental/mcp-parent/$RELEASE_VERSION - mkdir -p nexus/org/springframework/experimental/mcp/$RELEASE_VERSION - mkdir -p nexus/org/springframework/experimental/spring-ai-mcp/$RELEASE_VERSION - - - name: Download release files from Artifactory + mvn --batch-mode \ + -Prelease \ + -Pjavadoc \ + deploy env: - ARTIFACTORY_URL: "https://repo.spring.io/libs-staging-local/org/springframework/experimental" - ARTIFACTORY_USERNAME: ${{ secrets.ARTIFACTORY_USERNAME }} - ARTIFACTORY_PASSWORD: ${{ secrets.ARTIFACTORY_PASSWORD }} - run: | - echo "Downloading parent POM artifacts" - cd nexus/org/springframework/experimental/spring-ai-mcp/$RELEASE_VERSION - wget --user="$ARTIFACTORY_USERNAME" --password="$ARTIFACTORY_PASSWORD" $ARTIFACTORY_URL/spring-ai-mcp/$RELEASE_VERSION/spring-ai-mcp-$RELEASE_VERSION.pom - - echo "Downloading core artifacts" - cd ../../../../../.. - cd nexus/org/springframework/experimental/mcp/$RELEASE_VERSION - wget --user="$ARTIFACTORY_USERNAME" --password="$ARTIFACTORY_PASSWORD" $ARTIFACTORY_URL/mcp/$RELEASE_VERSION/mcp-$RELEASE_VERSION.pom - wget --user="$ARTIFACTORY_USERNAME" --password="$ARTIFACTORY_PASSWORD" $ARTIFACTORY_URL/mcp/$RELEASE_VERSION/mcp-$RELEASE_VERSION.jar - wget --user="$ARTIFACTORY_USERNAME" --password="$ARTIFACTORY_PASSWORD" $ARTIFACTORY_URL/mcp/$RELEASE_VERSION/mcp-$RELEASE_VERSION-javadoc.jar - wget --user="$ARTIFACTORY_USERNAME" --password="$ARTIFACTORY_PASSWORD" $ARTIFACTORY_URL/mcp/$RELEASE_VERSION/mcp-$RELEASE_VERSION-sources.jar - - echo "Downloading spring artifacts" - cd ../../../../../.. - cd nexus/org/springframework/experimental/spring-ai-mcp/$RELEASE_VERSION - wget --user="$ARTIFACTORY_USERNAME" --password="$ARTIFACTORY_PASSWORD" $ARTIFACTORY_URL/spring-ai-mcp/$RELEASE_VERSION/spring-ai-mcp-$RELEASE_VERSION.pom - wget --user="$ARTIFACTORY_USERNAME" --password="$ARTIFACTORY_PASSWORD" $ARTIFACTORY_URL/spring-ai-mcp/$RELEASE_VERSION/spring-ai-mcp-$RELEASE_VERSION.jar - wget --user="$ARTIFACTORY_USERNAME" --password="$ARTIFACTORY_PASSWORD" $ARTIFACTORY_URL/spring-ai-mcp/$RELEASE_VERSION/spring-ai-mcp-$RELEASE_VERSION-javadoc.jar - wget --user="$ARTIFACTORY_USERNAME" --password="$ARTIFACTORY_PASSWORD" $ARTIFACTORY_URL/spring-ai-mcp/$RELEASE_VERSION/spring-ai-mcp-$RELEASE_VERSION-sources.jar - - - name: Sign artifacts and release them to Maven Central - uses: spring-io/nexus-sync-action@main - id: nexus - with: - url: ${{ secrets.OSSRH_URL }} - username: ${{ secrets.OSSRH_S01_TOKEN_USERNAME }} - password: ${{ secrets.OSSRH_S01_TOKEN_PASSWORD }} - staging-profile-name: ${{ secrets.OSSRH_STAGING_PROFILE_NAME }} - create: true - upload: true - close: true - release: true - generate-checksums: true \ No newline at end of file + MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }} + MAVEN_PASSWORD: ${{ secrets.OSSRH_TOKEN }} + MAVEN_GPG_PASSPHRASE: ${{ secrets.SIGNING_PASSPHRASE }} diff --git a/.github/workflows/publish-snapshot.yml b/.github/workflows/publish-snapshot.yml new file mode 100644 index 000000000..5d9b4aa39 --- /dev/null +++ b/.github/workflows/publish-snapshot.yml @@ -0,0 +1,44 @@ +name: Publish Snapshot + +on: + push: + branches: [ "main" ] + +jobs: + build: + name: Build branch + runs-on: ubuntu-latest + steps: + - name: Checkout source code + uses: actions/checkout@v4 + + - name: Set up JDK 17 + uses: actions/setup-java@v4 + with: + java-version: '17' + distribution: 'temurin' + cache: 'maven' + server-id: central + server-username: MAVEN_USERNAME + server-password: MAVEN_PASSWORD + gpg-private-key: ${{ secrets.GPG_SECRET_KEY }} + gpg-passphrase: MAVEN_GPG_PASSPHRASE + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: '20' + + - name: Generate Java docs + run: mvn -Pjavadoc -B javadoc:aggregate + + - name: Build with Maven and deploy to Sonatype snapshot repository + env: + MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }} + MAVEN_PASSWORD: ${{ secrets.OSSRH_TOKEN }} + MAVEN_GPG_PASSPHRASE: ${{ secrets.SIGNING_PASSPHRASE }} + run: | + mvn -Pjavadoc -Prelease --batch-mode --update-snapshots deploy + + - name: Capture project version + run: echo PROJECT_VERSION=$(mvn help:evaluate -Dexpression=project.version --quiet -DforceStdout) >> $GITHUB_ENV diff --git a/LICENSE b/LICENSE index 261eeb9e9..5264c5bcc 100644 --- a/LICENSE +++ b/LICENSE @@ -1,201 +1,21 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +MIT License + +Copyright (c) 2025 the original author or authors. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md index f6116ae9d..ca87736cd 100644 --- a/README.md +++ b/README.md @@ -1,131 +1,36 @@ -# Java & Spring MCP -[![Build Status](https://github.com/spring-projects-experimental/spring-ai-mcp/actions/workflows/continuous-integration.yml/badge.svg)](https://github.com/spring-projects-experimental/spring-ai-mcp/actions/workflows/continuous-integration.yml) +# MCP Java SDK +[![Build Status](https://github.com/modelcontextprotocol/java-sdk/actions/workflows/publish-snapshot.yml/badge.svg)](https://github.com/modelcontextprotocol/java-sdk/actions/workflows/publish-snapshot.yml) -Set of projects that provide Java SDK and Spring Framework integration for the [Model Context Protocol](https://modelcontextprotocol.org/docs/concepts/architecture). -It enables Java applications to interact with AI models and tools through a standardized interface, supporting both synchronous and asynchronous communication patterns. +A set of projects that provide Java SDK integration for the [Model Context Protocol](https://modelcontextprotocol.org/docs/concepts/architecture). +This SDK enables Java applications to interact with AI models and tools through a standardized interface, supporting both synchronous and asynchronous communication patterns. ## 📚 Reference Documentation -For comprehensive guides and API documentation, visit the [Spring AI MCP Reference Documentation](https://docs.spring.io/spring-ai-mcp/reference/overview.html). - - - - -## Projects - -### [MCP Java SDK](https://docs.spring.io/spring-ai-mcp/reference/mcp.html) - -Java implementation of the Model Context Protocol specification. It includes: -- Synchronous and asynchronous [MCP Client](https://docs.spring.io/spring-ai-mcp/reference/mcp.html#_mcp_client) and [MCP Server](https://docs.spring.io/spring-ai-mcp/reference/mcp.html#_mcp_server) implementations -- Standard MCP operations support (tool discovery, resource management, prompt handling, structured logging). Support for request and notification handling. -- [Stdio](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) and [SSE](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) transport implementations. - -### MCP Transports - -#### Core Transports -- Stdio-based (`StdioClientTransport`, `StdioServerTransport`) for process-based communication -- Java HttpClient-based SSE client (`HttpClientSseClientTransport`) for HTTP streaming -- Servlet-based SSE server (`HttpServletSseServerTransport`) for HTTP SSE Server streaming using traditional Servlet API - -#### Optional SSE Transports -- [WebFlux SSE Transport](https://github.com/spring-projects-experimental/spring-ai-mcp/tree/main/mcp-transport/mcp-webflux-sse-transport) - Reactive HTTP streaming with Spring WebFlux (Client & Server) -- [WebMvc SSE Transport](https://github.com/spring-projects-experimental/spring-ai-mcp/tree/main/mcp-transport/mcp-webmvc-sse-transport) - Spring MVC based HTTP SSE transport (Server only). -You can use the core `HttpClientSseClientTransport` transport as a SSE client. - -### [Spring AI MCP](https://docs.spring.io/spring-ai-mcp/reference/spring-mcp.html) - -The Spring integration module provides Spring-specific functionality: -- Integration with Spring AI's function calling system -- Spring-friendly abstractions for MCP clients -- Auto-configurations (WIP) - - -## Installation - -Add the following dependencies to your Maven project: - -```xml - - - org.springframework.experimental - mcp - - - - - org.springframework.experimental - mcp-webflux-sse-transport - - - - - org.springframework.experimental - mcp-webmvc-sse-transport - - - - - org.springframework.experimental - spring-ai-mcp - -``` - -This is a milestone release, not available on Maven Central. -Add this repository to your POM: - -```xml - - - spring-milestones - Spring Milestones - https://repo.spring.io/milestone - - false - - - -``` - -Reffer to the [Dependency Management](https://docs.spring.io/spring-ai-mcp/reference/dependency-management.html) page for more information. - -## Example Demos - -Explore these MCP examples in the [spring-ai-examples/model-context-protocol](https://github.com/spring-projects/spring-ai-examples/tree/main/model-context-protocol) repository: - -- [SQLite Simple](https://github.com/spring-projects/spring-ai-examples/tree/main/model-context-protocol/sqlite/simple) - Demonstrates LLM integration with a database -- [SQLite Chatbot](https://github.com/spring-projects/spring-ai-examples/tree/main/model-context-protocol/sqlite/chatbot) - Interactive chatbot with SQLite database interaction -- [Filesystem](https://github.com/spring-projects/spring-ai-examples/tree/main/model-context-protocol/filesystem) - Enables LLM interaction with local filesystem folders and files -- [Brave](https://github.com/spring-projects/spring-ai-examples/tree/main/model-context-protocol/brave) - Enables natural language interactions with Brave Search, allowing you to perform internet searches. -- [Theme Park API Example](https://github.com/habuma/spring-ai-examples/tree/main/spring-ai-mcp) - Shows how to create an MCP server and client with Spring AI, exposing Theme Park API tools -- [Http SSE Client + WebMvc SSE Server](https://github.com/spring-projects/spring-ai-examples/tree/main/model-context-protocol/mcp-webmvc-server) - Showcases how to create and use MCP WebMvc servers and HttpClient clients with different capabilities. -- [WebFlux SSE Client + WebFlux SSE Server](https://github.com/spring-projects/spring-ai-examples/tree/main/model-context-protocol/mcp-webflux-server) - Showcases how to create and use MCP WebFlux servers and clients with different capabilities -- [HttpClient SSE Client + Servlet SSE Server](https://github.com/spring-projects/spring-ai-examples/tree/main/model-context-protocol/mcp-servlet-server) - Showcases how to create and use MCP Servlet SSE Server and HttpClient SSE Client with different capabilities - -## Documentation - -- [Java MCP SDK documentation](mcp/README.md) - - [Reference documentation](docs/ref-index.md) -- [Spring Integration documentation](spring-ai-mcp/README.md) +#### MCP Java SDK documentation +For comprehensive guides and SDK API documentation, visit the [MCP Java SDK Reference Documentation](https://modelcontextprotocol.io/sdk/java/mcp-overview). +#### Spring AI MCP documentation +[Spring AI MCP](https://docs.spring.io/spring-ai/reference/api/mcp/mcp-overview.html) extends the MCP Java SDK with Spring Boot integration, providing both [client](https://docs.spring.io/spring-ai/reference/api/mcp/mcp-client-boot-starter-docs.html) and [server](https://docs.spring.io/spring-ai/reference/api/mcp/mcp-server-boot-starter-docs.html) starters. Bootstrap your AI applications with MCP support using [Spring Initializer](https://start.spring.io). ## Development -- Building from Source +### Building from Source ```bash -mvn clean install +./mvnw clean install -DskipTests ``` -- Running Tests +### Running Tests + +To run the tests you have to pre-install `Docker` and `npx`. ```bash -mvn test +./mvnw test ``` - ## Contributing -This is an experimental Spring project. Contributions are welcome! Please: +Contributions are welcome! Please: 1. Fork the repository 2. Create a feature branch @@ -138,10 +43,10 @@ This is an experimental Spring project. Contributions are welcome! Please: ## Links -- [GitHub Repository](https://github.com/spring-projects-experimental/spring-ai-mcp) -- [Issue Tracker](https://github.com/spring-projects-experimental/spring-ai-mcp/issues) -- [CI/CD](https://github.com/spring-projects-experimental/spring-ai-mcp/actions) +- [GitHub Repository](https://github.com/modelcontextprotocol/java-sdk) +- [Issue Tracker](https://github.com/modelcontextprotocol/java-sdk/issues) +- [CI/CD](https://github.com/modelcontextprotocol/java-sdk/actions) ## License -This project is licensed under the [Apache License 2.0](LICENSE). +This project is licensed under the [MIT License](LICENSE). diff --git a/mcp-bom/pom.xml b/mcp-bom/pom.xml index 9abd5eb1e..5b37a4394 100644 --- a/mcp-bom/pom.xml +++ b/mcp-bom/pom.xml @@ -5,61 +5,55 @@ 4.0.0 - org.springframework.experimental + io.modelcontextprotocol.sdk mcp-parent - 0.6.0 + 0.8.0 mcp-bom pom - Spring AI MCP BOM - Spring AI MCP Bill of Materials + Java SDK MCP BOM + Java SDK MCP Bill of Materials - https://github.com/spring-projects-experimental/spring-ai-mcp + https://github.com/modelcontextprotocol/java-sdk - - https://github.com/spring-projects-experimental/spring-ai-mcp - git://github.com/spring-projects-experimental/spring-ai-mcp.git - git@github.com/spring-projects-experimental/spring-ai-mcp.git - + + https://github.com/modelcontextprotocol/java-sdk + git://github.com/modelcontextprotocol/java-sdk.git + git@github.com/modelcontextprotocol/java-sdk.git + - org.springframework.experimental + io.modelcontextprotocol.sdk mcp ${project.version} - org.springframework.experimental + io.modelcontextprotocol.sdk mcp-test ${project.version} - org.springframework.experimental - mcp-webflux-sse-transport + io.modelcontextprotocol.sdk + mcp-spring-webflux ${project.version} - org.springframework.experimental - mcp-webmvc-sse-transport + io.modelcontextprotocol.sdk + mcp-spring-webmvc ${project.version} - - - org.springframework.experimental - spring-ai-mcp - ${project.version} - diff --git a/mcp-docs/0.5.0-BREAKING-CHANGES.md b/mcp-docs/0.5.0-BREAKING-CHANGES.md deleted file mode 100644 index 298c917df..000000000 --- a/mcp-docs/0.5.0-BREAKING-CHANGES.md +++ /dev/null @@ -1,105 +0,0 @@ -# Breaking Changes for Spring AI MCP 0.5.0-SNAPSHOT - -## Major Changes - -1. Transport Implementation Modularization - - SSE transport implementations have been moved to dedicated modules: - - `mcp-webflux-sse-transport`: WebFlux-based SSE transport - - `mcp-webmvc-sse-transport`: WebMVC-based SSE transport - - Base MCP module no longer includes transport implementations - -2. Class Renames - - `SseClientTransport` → `WebFluxSseClientTransport` - - `SseServerTransport` → `WebFluxSseServerTransport` - -3. Project Structure - - Test utilities moved to new `mcp-test` module - - All samples moved to the the https://github.com/spring-projects/spring-ai-examples/tree/main/model-context-protocol - -## Migration Guide - -### 1. Update Dependencies - -If using SSE transport, add the appropriate transport module to your pom.xml: - -For WebFlux SSE transport: -```xml - - org.springframework.experimental - mcp-webflux-sse-transport - 0.5.0-SNAPSHOT - -``` - -For WebMVC SSE transport: -```xml - - org.springframework.experimental - mcp-webmvc-sse-transport - 0.5.0-SNAPSHOT - -``` - -### 2. Update Transport Class References - -Replace: -```java -import org.springframework.ai.mcp.client.transport.SseClientTransport; -import org.springframework.ai.mcp.server.transport.SseServerTransport; -``` - -With: -```java -// For WebFlux -import org.springframework.ai.mcp.client.transport.WebFluxSseClientTransport; -import org.springframework.ai.mcp.server.transport.WebFluxSseServerTransport; - -// Or for WebMVC -import org.springframework.ai.mcp.server.transport.WebMvcSseServerTransport; -``` - -### 3. Update Transport Instantiation - -Replace: -```java -var transport = new SseClientTransport(webClientBuilder); -``` - -With: -```java -var transport = new WebFluxSseClientTransport(webClientBuilder); -``` - -Replace: -```java -var serverTransport = new SseServerTransport(objectMapper, messageEndpoint); -``` - -With: -```java -// For WebFlux -var serverTransport = new WebFluxSseServerTransport(objectMapper, messageEndpoint); - -// Or for WebMVC -var serverTransport = new WebMvcSseServerTransport(objectMapper, messageEndpoint); -``` - -### 4. Test Dependencies - -If you're using MCP test utilities, add the test module dependency: - -```xml - - org.springframework.experimental - mcp-test - 0.5.0-SNAPSHOT - test - -``` - -## Other Changes - -- Spring AI dependency updated to 1.0.0-M5 -- Improved dependency management and reduced transitive dependencies -- Sample application restructured under samples/ directory -- New test utilities and base test classes in mcp-test module diff --git a/mcp-docs/0.6.0-MIGRATION-GUIDE.md b/mcp-docs/0.6.0-MIGRATION-GUIDE.md deleted file mode 100644 index 8f3b1576c..000000000 --- a/mcp-docs/0.6.0-MIGRATION-GUIDE.md +++ /dev/null @@ -1,239 +0,0 @@ -# Spring AI MCP 0.6.0 Migration Guide - -This guide outlines the steps required to migrate your code to Spring AI MCP 0.6.0. - -## Key Changes - -- New builder patterns for `McpClient` and `McpServer` -- Introduction of dedicated feature classes for sync/async operations -- Enhanced type safety and reactive support -- Deprecated methods and classes marked for removal - -## Client Migration - -### Creating Clients - -Before: -```java -// Sync client -McpClient.using(transport) - .requestTimeout(Duration.ofSeconds(5)) - .sync(); - -// Async client -McpClient.using(transport) - .requestTimeout(Duration.ofSeconds(5)) - .async(); -``` - -After: -```java -// Sync client -McpClient.sync(transport) - .requestTimeout(Duration.ofSeconds(5)) - .build(); - -// Async client -McpClient.async(transport) - .requestTimeout(Duration.ofSeconds(5)) - .build(); -``` - -### Change Consumers - -Before: -```java -// Sync client -McpClient.using(transport) - .toolsChangeConsumer(tools -> handleTools(tools)) - .resourcesChangeConsumer(resources -> handleResources(resources)) - .promptsChangeConsumer(prompts -> handlePrompts(prompts)) - .sync(); - -// Async client -McpClient.using(transport) - .toolsChangeConsumer(tools -> handleTools(tools)) - .resourcesChangeConsumer(resources -> handleResources(resources)) - .promptsChangeConsumer(prompts -> handlePrompts(prompts)) - .async(); -``` - -After: -```java -// Sync client -McpClient.sync(transport) - .toolsChangeConsumer(tools -> handleTools(tools)) - .resourcesChangeConsumer(resources -> handleResources(resources)) - .promptsChangeConsumer(prompts -> handlePrompts(prompts)) - .build(); - -// Async client -McpClient.async(transport) - .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> handleTools(tools))) - .resourcesChangeConsumer(resources -> Mono.fromRunnable(() -> handleResources(resources))) - .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> handlePrompts(prompts))) - .build(); -``` - -### Sampling Handlers - -Before: -```java -// Sync client -McpClient.using(transport) - .sampling(request -> new CreateMessageResult("response")) - .sync(); - -// Async client -McpClient.using(transport) - .sampling(request -> new CreateMessageResult("response")) - .async(); -``` - -After: -```java -// Sync client -McpClient.sync(transport) - .sampling(request -> new CreateMessageResult("response")) - .build(); - -// Async client -McpClient.async(transport) - .sampling(request -> Mono.just(new CreateMessageResult("response"))) - .build(); -``` - -## Server Migration - -### Creating Servers - -Before: -```java -// Sync server -McpServer.using(transport) - .serverInfo("test-server", "1.0.0") - .sync(); - -// Async server -McpServer.using(transport) - .serverInfo("test-server", "1.0.0") - .async(); -``` - -After: -```java -// Sync server -McpServer.sync(transport) - .serverInfo("test-server", "1.0.0") - .build(); - -// Async server -McpServer.async(transport) - .serverInfo("test-server", "1.0.0") - .build(); -``` - -### Tool Registration - -Before: -```java -// Using ToolRegistration record -new ToolRegistration( - new Tool("calculator", "Performs calculations", schema), - args -> new CallToolResult("result") -); -``` - -After: -```java -// Sync server -new McpServerFeatures.SyncToolRegistration( - new Tool("calculator", "Performs calculations", schema), - args -> new CallToolResult("result") -); - -// Async server -new McpServerFeatures.AsyncToolRegistration( - new Tool("calculator", "Performs calculations", schema), - args -> Mono.just(new CallToolResult("result")) -); -``` - -### Resource Registration - -Before: -```java -// Using ResourceRegistration record -new ResourceRegistration( - new Resource("docs", "Documentation", "text/markdown"), - request -> new ReadResourceResult(content) -); -``` - -After: -```java -// Sync server -new McpServerFeatures.SyncResourceRegistration( - new Resource("docs", "Documentation", "text/markdown"), - request -> new ReadResourceResult(content) -); - -// Async server -new McpServerFeatures.AsyncResourceRegistration( - new Resource("docs", "Documentation", "text/markdown"), - request -> Mono.just(new ReadResourceResult(content)) -); -``` - -### Prompt Registration - -Before: -```java -// Using PromptRegistration record -new PromptRegistration( - new Prompt("analyze", "Code analysis"), - request -> new GetPromptResult("result") -); -``` - -After: -```java -// Sync server -new McpServerFeatures.SyncPromptRegistration( - new Prompt("analyze", "Code analysis"), - request -> new GetPromptResult("result") -); - -// Async server -new McpServerFeatures.AsyncPromptRegistration( - new Prompt("analyze", "Code analysis"), - request -> Mono.just(new GetPromptResult("result")) -); -``` - -## Spring Integration Changes - -### Tool Helper Changes - -Before: -```java -ToolHelper.toToolRegistration(functionCallback); -ToolHelper.toToolRegistration(functionCallbacks); -``` - -After: -```java -ToolHelper.toSyncToolRegistration(functionCallback); -ToolHelper.toSyncToolRegistration(functionCallbacks); -``` - -## Deprecated APIs - -The following APIs are deprecated and will be removed in a future release: - -- `McpClient.using()` - Use `McpClient.sync()` or `McpClient.async()` instead -- `McpServer.using()` - Use `McpServer.sync()` or `McpServer.async()` instead -- `McpServer.ToolRegistration` - Use `McpServerFeatures.SyncToolRegistration` or `McpServerFeatures.AsyncToolRegistration` instead -- `McpServer.ResourceRegistration` - Use `McpServerFeatures.SyncResourceRegistration` or `McpServerFeatures.AsyncResourceRegistration` instead -- `McpServer.PromptRegistration` - Use `McpServerFeatures.SyncPromptRegistration` or `McpServerFeatures.AsyncPromptRegistration` instead -- `ToolHelper.toToolRegistration()` - Use `ToolHelper.toSyncToolRegistration()` instead diff --git a/mcp-docs/pom.xml b/mcp-docs/pom.xml deleted file mode 100644 index 9f0d61b1d..000000000 --- a/mcp-docs/pom.xml +++ /dev/null @@ -1,89 +0,0 @@ - - - - - 4.0.0 - - org.springframework.experimental - mcp-parent - 0.6.0 - - mcp-docs - Spring AI MCP Docs - Spring AI MCP documentation - - - - - org.antora - antora-maven-plugin - ${org.maven.antora-version} - true - - - - - - - src/main/antora/antora-playbook.yml - - @antora/cli@3.2.0-alpha.6 - @antora/atlas-extension@1.0.0-alpha.2 - @antora/collector-extension@1.0.0-beta.1 - @asciidoctor/tabs@1.0.0-beta.6 - @springio/antora-extensions@1.14.2 - @springio/asciidoctor-extensions@1.0.0-alpha.12 - @djencks/asciidoctor-mathjax@0.0.9 - - - - - io.spring.maven.antora - antora-component-version-maven-plugin - ${io.spring.maven.antora-version} - - - - antora-component-version - - - - - - org.apache.maven.plugins - maven-assembly-plugin - ${maven-assembly-plugin.version} - - - src/assembly/javadocs.xml - - spring-ai-mcp-${project.version} - true - - - - org.apache.maven.plugins - maven-deploy-plugin - ${maven-deploy-plugin.version} - - true - - - - - - \ No newline at end of file diff --git a/mcp-docs/src/assembly/javadocs.xml b/mcp-docs/src/assembly/javadocs.xml deleted file mode 100644 index a12b36e8e..000000000 --- a/mcp-docs/src/assembly/javadocs.xml +++ /dev/null @@ -1,31 +0,0 @@ - - - - docs - - zip - - false - - - ../target/site/apidocs - api - - - \ No newline at end of file diff --git a/mcp-docs/src/main/antora/antora-playbook.yml b/mcp-docs/src/main/antora/antora-playbook.yml deleted file mode 100644 index bb45a0b8e..000000000 --- a/mcp-docs/src/main/antora/antora-playbook.yml +++ /dev/null @@ -1,48 +0,0 @@ -# PACKAGES antora@3.2.0-alpha.6 @antora/atlas-extension:1.0.0-alpha.1 @antora/collector-extension@1.0.0-alpha.3 @springio/antora-extensions@1.1.0-alpha.2 @asciidoctor/tabs@1.0.0-alpha.12 @opendevise/antora-release-line-extension@1.0.0-alpha.2 -# -# The purpose of this Antora playbook is to build the docs in the current branch. -antora: - extensions: - - '@antora/collector-extension' - # - require: '@springio/antora-extensions/root-component-extension' - - require: '@springio/antora-extensions' - root_component_name: 'ai' -site: - title: Spring AI Reference - url: https://docs.spring.io/spring-ai-mcp/reference - robots: allow -git: - ensure_git_suffix: false -content: - sources: - - url: ./../../../.. - branches: HEAD - start_path: mcp-docs/src/main/antora - worktrees: true -asciidoc: - attributes: - page-related-doc-categories: ai,java,ml,mcp - page-pagination: '' - hide-uri-scheme: '@' - tabs-sync-option: '@' - chomp: 'all' - stem: 'asciimath' - extensions: - - '@asciidoctor/tabs' - - '@springio/asciidoctor-extensions' - - '@springio/asciidoctor-extensions/javadoc-extension' - - '@springio/asciidoctor-extensions/include-code-extension' - - '@djencks/asciidoctor-mathjax' - sourcemap: true -urls: - latest_version_segment_strategy: redirect:to - latest_version_segment: '' - redirect_facility: httpd -runtime: - log: - failure_level: warn - format: pretty -ui: - bundle: - url: https://github.com/spring-io/antora-ui-spring/releases/download/v0.4.17/ui-bundle.zip - snapshot: true \ No newline at end of file diff --git a/mcp-docs/src/main/antora/antora.yml b/mcp-docs/src/main/antora/antora.yml deleted file mode 100644 index 916c760c5..000000000 --- a/mcp-docs/src/main/antora/antora.yml +++ /dev/null @@ -1,12 +0,0 @@ -name: ai -version: true -title: Spring AI -nav: - - modules/ROOT/nav.adoc -ext: - collector: - - run: - command: mvnw process-resources - local: true - scan: - dir: mcp-docs/target/classes/antora-resources \ No newline at end of file diff --git a/mcp-docs/src/main/antora/modules/ROOT/images/class-diagrams.puml b/mcp-docs/src/main/antora/modules/ROOT/images/class-diagrams.puml deleted file mode 100644 index 5b08c738c..000000000 --- a/mcp-docs/src/main/antora/modules/ROOT/images/class-diagrams.puml +++ /dev/null @@ -1,282 +0,0 @@ -@startuml Core Components - -' Core Interfaces -interface McpTransport { - +Mono connect(Function, Mono> handler) - +Mono sendMessage(JSONRPCMessage message) - +void close() - +Mono closeGracefully() - + T unmarshalFrom(Object data, TypeReference typeRef) -} - -interface McpSession { - + Mono sendRequest(String method, Object requestParams, TypeReference typeRef) - +Mono sendNotification(String method, Map params) - +Mono closeGracefully() - +void close() -} - -' Core Implementation Classes -class DefaultMcpSession { - +interface RequestHandler - +interface NotificationHandler -} - -' Client Classes -class McpClient { - +{static} Builder using(ClientMcpTransport transport) -} - -class McpAsyncClient { - +Mono initialize() - +ServerCapabilities getServerCapabilities() - +Implementation getServerInfo() - +ClientCapabilities getClientCapabilities() - +Implementation getClientInfo() - +void close() - +Mono closeGracefully() - +Mono ping() - +Mono addRoot(Root root) - +Mono removeRoot(String rootUri) - +Mono rootsListChangedNotification() - +Mono callTool(CallToolRequest request) - +Mono listTools() - +Mono listResources() - +Mono readResource(ReadResourceRequest request) - +Mono listResourceTemplates() - +Mono subscribeResource(SubscribeRequest request) - +Mono unsubscribeResource(UnsubscribeRequest request) - +Mono listPrompts() - +Mono getPrompt(GetPromptRequest request) - +Mono setLoggingLevel(LoggingLevel level) -} - -class McpSyncClient { - +InitializeResult initialize() - +ServerCapabilities getServerCapabilities() - +Implementation getServerInfo() - +ClientCapabilities getClientCapabilities() - +Implementation getClientInfo() - +void close() - +boolean closeGracefully() - +Object ping() - +void addRoot(Root root) - +void removeRoot(String rootUri) - +void rootsListChangedNotification() - +CallToolResult callTool(CallToolRequest request) - +ListToolsResult listTools() - +ListResourcesResult listResources() - +ReadResourceResult readResource(ReadResourceRequest request) - +ListResourceTemplatesResult listResourceTemplates() - +void subscribeResource(SubscribeRequest request) - +void unsubscribeResource(UnsubscribeRequest request) - +ListPromptsResult listPrompts() - +GetPromptResult getPrompt(GetPromptRequest request) - +void setLoggingLevel(LoggingLevel level) -} - -' Server Classes -class McpServer { - +{static} Builder using(ServerMcpTransport transport) -} - -class McpAsyncServer { - - +ServerCapabilities getServerCapabilities() - +Implementation getServerInfo() - +ClientCapabilities getClientCapabilities() - +Implementation getClientInfo() - +void close() - +Mono closeGracefully() - - ' Tool Management - +Mono addTool(ToolRegistration toolRegistration) - +Mono removeTool(String toolName) - +Mono notifyToolsListChanged() - - ' Resource Management - +Mono addResource(ResourceRegistration resourceHandler) - +Mono removeResource(String resourceUri) - +Mono notifyResourcesListChanged() - - ' Prompt Management - +Mono addPrompt(PromptRegistration promptRegistration) - +Mono removePrompt(String promptName) - +Mono notifyPromptsListChanged() - - ' Logging - +Mono loggingNotification(LoggingMessageNotification notification) - - ' Sampling - +Mono createMessage(CreateMessageRequest request) -} - -class McpSyncServer { - +McpAsyncServer getAsyncServer() - - +ServerCapabilities getServerCapabilities() - +Implementation getServerInfo() - +ClientCapabilities getClientCapabilities() - +Implementation getClientInfo() - +void close() - +void closeGracefully() - - ' Tool Management - +void addTool(ToolRegistration toolHandler) - +void removeTool(String toolName) - +void notifyToolsListChanged() - - ' Resource Management - +void addResource(ResourceRegistration resourceHandler) - +void removeResource(String resourceUri) - +void notifyResourcesListChanged() - - ' Prompt Management - +void addPrompt(PromptRegistration promptRegistration) - +void removePrompt(String promptName) - +void notifyPromptsListChanged() - - ' Logging - +void loggingNotification(LoggingMessageNotification notification) - - ' Sampling - +CreateMessageResult createMessage(CreateMessageRequest request) -} - -' Transport Implementations -class StdioClientTransport implements ClientMcpTransport { - +void setErrorHandler(Consumer errorHandler) - +Sinks.Many getErrorSink() -} - -class StdioServerTransport implements ServerMcpTransport { -} - - -class HttpServletSseServerTransport implements ServerMcpTransport { -} - - -class HttpClientSseClientTransport implements ClientMcpTransport { -} - - -class WebFluxSseClientTransport implements ClientMcpTransport { -} - - -class WebFluxSseServerTransport implements ServerMcpTransport { - +RouterFunction getRouterFunction() -} - -class WebMvcSseServerTransport implements ServerMcpTransport { - +RouterFunction getRouterFunction() -} - - -' Schema and Error Classes -class McpSchema { - +class ErrorCodes - +interface Request - +interface JSONRPCMessage - +interface ResourceContents - +interface Content - +interface ServerCapabilities - +{static} JSONRPCMessage deserializeJsonRpcMessage() -} - -class McpError { -} - -' Relationships -McpTransport <|.. ClientMcpTransport -McpTransport <|.. ServerMcpTransport - -McpSession <|.. DefaultMcpSession -DefaultMcpSession --o McpAsyncClient -DefaultMcpSession --o McpAsyncServer - -McpClient ..> McpAsyncClient : creates -McpClient ..> McpSyncClient : creates -McpSyncClient --> McpAsyncClient : delegates to - -McpServer ..> McpAsyncServer : creates -McpServer ..> McpSyncServer : creates -McpSyncServer o-- McpAsyncServer - -DefaultMcpSession o-- McpTransport -McpSchema <.. McpSession : uses -McpError ..> McpSession : throws - -@enduml - -@startuml Message Flow - -package "MCP Schema" { - interface JSONRPCMessage { - +String jsonrpc() - } - - interface Request { - } - - class InitializeRequest - class CallToolRequest - class ListToolsRequest - class ListResourcesRequest - class ReadResourceRequest - class ListResourceTemplatesRequest - class ListPromptsRequest - class GetPromptRequest -} - -package "Resource Types" { - interface ResourceContents { - +String uri() - +String mimeType() - } - - class TextResourceContents - class BlobResourceContents - - interface Content { - +String type() - } - - class TextContent - class ImageContent - class EmbeddedResource - - interface Annotated { - +Annotations annotations() - } - - interface PromptOrResourceReference { - +String type() - } - - class PromptReference - class ResourceReference -} - -JSONRPCMessage <|.. Request -Request <|.. InitializeRequest -Request <|.. CallToolRequest -Request <|.. ListToolsRequest -Request <|.. ListResourcesRequest -Request <|.. ReadResourceRequest -Request <|.. ListResourceTemplatesRequest -Request <|.. ListPromptsRequest -Request <|.. GetPromptRequest - -ResourceContents <|.. TextResourceContents -ResourceContents <|.. BlobResourceContents - -Content <|.. TextContent -Content <|.. ImageContent -Content <|.. EmbeddedResource - -PromptOrResourceReference <|.. PromptReference -PromptOrResourceReference <|.. ResourceReference - -@enduml diff --git a/mcp-docs/src/main/antora/modules/ROOT/images/mcp-stack.svg b/mcp-docs/src/main/antora/modules/ROOT/images/mcp-stack.svg deleted file mode 100644 index 3847eaa8d..000000000 --- a/mcp-docs/src/main/antora/modules/ROOT/images/mcp-stack.svg +++ /dev/null @@ -1,197 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/mcp-docs/src/main/antora/modules/ROOT/images/spring-ai-mcp-clinet-architecture.jpg b/mcp-docs/src/main/antora/modules/ROOT/images/spring-ai-mcp-clinet-architecture.jpg deleted file mode 100644 index f858b8d5c..000000000 Binary files a/mcp-docs/src/main/antora/modules/ROOT/images/spring-ai-mcp-clinet-architecture.jpg and /dev/null differ diff --git a/mcp-docs/src/main/antora/modules/ROOT/images/spring-ai-mcp-server-architecture.jpg b/mcp-docs/src/main/antora/modules/ROOT/images/spring-ai-mcp-server-architecture.jpg deleted file mode 100644 index c6ebea0c3..000000000 Binary files a/mcp-docs/src/main/antora/modules/ROOT/images/spring-ai-mcp-server-architecture.jpg and /dev/null differ diff --git a/mcp-docs/src/main/antora/modules/ROOT/images/spring-ai-mcp-uml-classdiagram.svg b/mcp-docs/src/main/antora/modules/ROOT/images/spring-ai-mcp-uml-classdiagram.svg deleted file mode 100644 index f83a586e7..000000000 --- a/mcp-docs/src/main/antora/modules/ROOT/images/spring-ai-mcp-uml-classdiagram.svg +++ /dev/null @@ -1 +0,0 @@ -McpTransportMono<Void> connect(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> handler)Mono<Void> sendMessage(JSONRPCMessage message)void close()Mono<Void> closeGracefully()<T> T unmarshalFrom(Object data, TypeReference<T> typeRef)McpSession<T> Mono<T> sendRequest(String method, Object requestParams, TypeReference<T> typeRef)Mono<Void> sendNotification(String method, Map<String, Object> params)Mono<Void> closeGracefully()void close()DefaultMcpSessioninterface RequestHandlerinterface NotificationHandlerMcpClientBuilder using(ClientMcpTransport transport)McpAsyncClientMono<InitializeResult> initialize()ServerCapabilities getServerCapabilities()Implementation getServerInfo()ClientCapabilities getClientCapabilities()Implementation getClientInfo()void close()Mono<Void> closeGracefully()Mono<Object> ping()Mono<Void> addRoot(Root root)Mono<Void> removeRoot(String rootUri)Mono<Void> rootsListChangedNotification()Mono<CallToolResult> callTool(CallToolRequest request)Mono<ListToolsResult> listTools()Mono<ListResourcesResult> listResources()Mono<ReadResourceResult> readResource(ReadResourceRequest request)Mono<ListResourceTemplatesResult> listResourceTemplates()Mono<Void> subscribeResource(SubscribeRequest request)Mono<Void> unsubscribeResource(UnsubscribeRequest request)Mono<ListPromptsResult> listPrompts()Mono<GetPromptResult> getPrompt(GetPromptRequest request)Mono<Void> setLoggingLevel(LoggingLevel level)McpSyncClientInitializeResult initialize()ServerCapabilities getServerCapabilities()Implementation getServerInfo()ClientCapabilities getClientCapabilities()Implementation getClientInfo()void close()boolean closeGracefully()Object ping()void addRoot(Root root)void removeRoot(String rootUri)void rootsListChangedNotification()CallToolResult callTool(CallToolRequest request)ListToolsResult listTools()ListResourcesResult listResources()ReadResourceResult readResource(ReadResourceRequest request)ListResourceTemplatesResult listResourceTemplates()void subscribeResource(SubscribeRequest request)void unsubscribeResource(UnsubscribeRequest request)ListPromptsResult listPrompts()GetPromptResult getPrompt(GetPromptRequest request)void setLoggingLevel(LoggingLevel level)McpServerBuilder using(ServerMcpTransport transport)McpAsyncServerServerCapabilities getServerCapabilities()Implementation getServerInfo()ClientCapabilities getClientCapabilities()Implementation getClientInfo()void close()Mono<Void> closeGracefully() Mono<Void> addTool(ToolRegistration toolRegistration)Mono<Void> removeTool(String toolName)Mono<Void> notifyToolsListChanged() Mono<Void> addResource(ResourceRegistration resourceHandler)Mono<Void> removeResource(String resourceUri)Mono<Void> notifyResourcesListChanged() Mono<Void> addPrompt(PromptRegistration promptRegistration)Mono<Void> removePrompt(String promptName)Mono<Void> notifyPromptsListChanged() Mono<Void> loggingNotification(LoggingMessageNotification notification) Mono<CreateMessageResult> createMessage(CreateMessageRequest request)McpSyncServerMcpAsyncServer getAsyncServer() ServerCapabilities getServerCapabilities()Implementation getServerInfo()ClientCapabilities getClientCapabilities()Implementation getClientInfo()void close()void closeGracefully() void addTool(ToolRegistration toolHandler)void removeTool(String toolName)void notifyToolsListChanged() void addResource(ResourceRegistration resourceHandler)void removeResource(String resourceUri)void notifyResourcesListChanged() void addPrompt(PromptRegistration promptRegistration)void removePrompt(String promptName)void notifyPromptsListChanged() void loggingNotification(LoggingMessageNotification notification) CreateMessageResult createMessage(CreateMessageRequest request)StdioClientTransportvoid setErrorHandler(Consumer<String> errorHandler)Sinks.Many<String> getErrorSink()ClientMcpTransportStdioServerTransportServerMcpTransportHttpServletSseServerTransportHttpClientSseClientTransportWebFluxSseClientTransportWebFluxSseServerTransportRouterFunction<?> getRouterFunction()WebMvcSseServerTransportRouterFunction<?> getRouterFunction()McpSchemaclass ErrorCodesinterface Requestinterface JSONRPCMessageinterface ResourceContentsinterface Contentinterface ServerCapabilitiesJSONRPCMessage deserializeJsonRpcMessage()McpErrorcreatescreatesdelegates tocreatescreatesusesthrows \ No newline at end of file diff --git a/mcp-docs/src/main/antora/modules/ROOT/nav.adoc b/mcp-docs/src/main/antora/modules/ROOT/nav.adoc deleted file mode 100644 index 115ff344b..000000000 --- a/mcp-docs/src/main/antora/modules/ROOT/nav.adoc +++ /dev/null @@ -1,4 +0,0 @@ -* xref:overview.adoc[Overview] -** xref:dependency-management.adoc[Dependency Management] -* xref:mcp.adoc[MCP Java SDK] -* xref:spring-mcp.adoc[Spring AI MCP] diff --git a/mcp-docs/src/main/antora/modules/ROOT/pages/dependency-management.adoc b/mcp-docs/src/main/antora/modules/ROOT/pages/dependency-management.adoc deleted file mode 100644 index 0f01bc4bf..000000000 --- a/mcp-docs/src/main/antora/modules/ROOT/pages/dependency-management.adoc +++ /dev/null @@ -1,110 +0,0 @@ -[[dependency-management]] -= Dependency Management - -[[mcp-bom]] -== Bill of Materials (BOM) - -The Bill of Materials (BOM) declares the recommended versions of all the dependencies used by a given release. -Using the BOM from your application’s build script avoids the need for you to specify and maintain the dependency versions yourself. -Instead, the version of the BOM you’re using determines the utilized dependency versions. -It also ensures that you’re using supported and tested versions of the dependencies by default, unless you choose to override them. - -Add the BOM to your project: - -[tabs] -====== -Maven:: -+ -[source,xml,indent=0,subs="verbatim,quotes"] ----- - - - - org.springframework.experimental - mcp-bom - 0.6.0-SNAPSHOT - pom - import - - - ----- - -Gradle:: -+ -[source,groovy,indent=0,subs="verbatim,quotes"] ----- -dependencies { - implementation platform("org.springframework.experimental:mcp-bom:0.6.0-SNAPSHOT") - //... -} ----- -Gradle users can also use the Spring AI MCP BOM by leveraging Gradle (5.0+) native support for declaring dependency constraints using a Maven BOM. -This is implemented by adding a 'platform' dependency handler method to the dependencies section of your Gradle build script. -As shown in the snippet below this can then be followed by version-less declarations of the Starter Dependencies for the one or more spring-ai modules you wish to use, e.g. spring-ai-openai. -====== - -Replace the version number with the version of the BOM you want to use. - -[[dependencies]] -== Available Dependencies - -The following dependencies are available and managed by the BOM: - -=== Core Dependencies - -* `org.springframework.experimental:mcp` - Core MCP library providing the base functionality and APIs for Model Context Protocol implementation. -* `org.springframework.experimental:spring-ai-mcp` - Spring AI integration with MCP, providing Spring-specific features and utilities. - -=== Transport Dependencies - -* `org.springframework.experimental:mcp-webflux-sse-transport` - WebFlux-based Server-Sent Events (SSE) transport implementation for reactive applications. -* `org.springframework.experimental:mcp-webmvc-sse-transport` - WebMVC-based Server-Sent Events (SSE) transport implementation for servlet-based applications. - -=== Testing Dependencies - -* `org.springframework.experimental:mcp-test` - Testing utilities and support for MCP-based applications. - -[[repositories]] -=== Milestone and Snapshot Repositories - -To use the Milestone and Snapshot version, you need to add references to the Spring Milestone and/or Snapshot repositories in your build file. -Add the following repository definitions to your Maven or Gradle build file: - -[tabs] -====== -Maven:: -+ -[source,xml,indent=0,subs="verbatim,quotes"] ----- - - - spring-milestones - Spring Milestones - https://repo.spring.io/milestone - - false - - - - spring-snapshots - Spring Snapshots - https://repo.spring.io/snapshot - - false - - - ----- - -Gradle:: -+ -[source,groovy,indent=0,subs="verbatim,quotes"] ----- -repositories { - mavenCentral() - maven { url 'https://repo.spring.io/milestone' } - maven { url 'https://repo.spring.io/snapshot' } -} ----- -====== diff --git a/mcp-docs/src/main/antora/modules/ROOT/pages/mcp.adoc b/mcp-docs/src/main/antora/modules/ROOT/pages/mcp.adoc deleted file mode 100644 index dd3ed6ff9..000000000 --- a/mcp-docs/src/main/antora/modules/ROOT/pages/mcp.adoc +++ /dev/null @@ -1,675 +0,0 @@ -= Java MCP SDK - -Java SDK implementation of the link:https://modelcontextprotocol.io/introduction[Model Context Protocol], enabling seamless integration with language models and AI tools. - -== Features - -* Synchronous and Asynchronous MCP Client and MCP Server implementations -* Standard MCP operations support: -** Protocol link:https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization[version compatibility negotiation] -** link:https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/[Tool] discovery, execution, list change notifications -** link:https://spec.modelcontextprotocol.io/specification/2024-11-05/server/resources/[Resource] management with URI templates -** link:https://spec.modelcontextprotocol.io/specification/2024-11-05/client/roots/[Roots] list management and notifications -** link:https://spec.modelcontextprotocol.io/specification/2024-11-05/server/prompts/[Prompt] handling and management -** link:https://spec.modelcontextprotocol.io/specification/2024-11-05/client/sampling/[Sampling] support for AI model interactions -* Multiple transport implementations: -** Core transports: -*** Stdio-based transport for process-based communication -*** Java HttpClient-based SSE client transport for HTTP SSE Client-side streaming -*** Servlet-based SSE server transport for HTTP SSE Server streaming -** Spring-based transports: -*** WebFlux SSE transport for reactive HTTP streaming -*** WebMVC SSE transport for servlet-based HTTP streaming - -=== Dependencies - -Add the following dependency to your Maven project: - -[tabs] -====== -Maven:: -+ -The core MCP functionality: -+ -[source,xml] ----- - - org.springframework.experimental - mcp - ----- -+ -For HTTP SSE transport implementations, add one of the following dependencies -(Note that you already have HTTP Client SSE client transport in the core module): -+ -[source,xml] ----- - - - org.springframework.experimental - mcp-webflux-sse-transport - - - - - org.springframework.experimental - mcp-webmvc-sse-transport - ----- - -Gradle:: -+ -The core MCP functionality: -+ -[source,groovy] ----- -dependencies { - implementation 'org.springframework.experimental:mcp' -} ----- -+ -For HTTP SSE transport implementations, add one of the following dependencies -(note that you already have HTTP Client SSE client transport in the core module): -+ -[source,groovy] ----- -// Spring WebFlux-based SSE client and server transport -implementation 'org.springframework.experimental:mcp-webflux-sse-transport' - -// Spring WebMVC-based SSE server transport -implementation 'org.springframework.experimental:mcp-webmvc-sse-transport' ----- -====== - -Reffer to the xref:dependency-management.adoc[Dependency Management] page for more information. - -== Architecture - -image::mcp-stack.svg[width=400,float=right] - -The SDK follows a layered architecture with clear separation of concerns: - -* *Client/Server Layer*: Both use McpSession for sync/async operations, with McpClient handling client-side protocol operations and McpServer managing server-side protocol operations. -* *Session Layer (McpSession)*: Manages communication patterns and state using DefaultMcpSession implementation. -* *Transport Layer (McpTransport)*: Handles JSON-RPC message serialization/deserialization via: -** StdioTransport (stdin/stdout) in the core module -** HTTP SSE transports in dedicated transport modules (Java HttpClient, Spring WebFlux, Spring WebMVC) - -Following class diagram illustrates the layered architecture of the MCP SDK, showing the relationships between core interfaces (McpTransport, McpSession), their implementations, and the client/server components. It highlights how the transport layer connects to sessions, which in turn support both synchronous and asynchronous client/server implementations. - -image::spring-ai-mcp-uml-classdiagram.svg[width=1000] - -Key Interactions: - -* *Client/Server Initialization*: Transport setup, protocol compatibility check, capability negotiation, and implementation details exchange. -* *Message Flow*: JSON-RPC message handling with validation, type-safe response processing, and error handling. -* *Resource Management*: Resource discovery, URI template-based access, subscription system, and content retrieval. -* *Prompt System*: Discovery, parameter-based retrieval, change notifications, and content management. -* *Tool Execution*: Discovery, parameter validation, timeout-aware execution, and result processing. - -[[mcp-client]] -== MCP Client - -The MCP Client is a key component in the Model Context Protocol (MCP) architecture, responsible for establishing and managing connections with MCP servers. It implements the client-side of the protocol, handling: - -* Protocol version negotiation to ensure compatibility with servers -* Capability negotiation to determine available features -* Message transport and JSON-RPC communication -* Tool discovery and execution -* Resource access and management -* Prompt system interactions -* Optional features like roots management and sampling support - -The client provides both synchronous and asynchronous APIs for flexibility in different application contexts. - -[tabs] -====== -Sync API:: -+ -[source,java] ----- -// Create a sync client with custom configuration -McpSyncClient client = McpClient.sync(transport) - .requestTimeout(Duration.ofSeconds(10)) - .capabilities(ClientCapabilities.builder() - .roots(true) // Enable roots capability - .sampling() // Enable sampling capability - .build()) - .sampling(request -> new CreateMessageResult(response)) - .build(); - -// Initialize connection -client.initialize(); - -// List available tools -ListToolsResult tools = client.listTools(); - -// Call a tool -CallToolResult result = client.callTool( - new CallToolRequest("calculator", - Map.of("operation", "add", "a", 2, "b", 3)) -); - -// List and read resources -ListResourcesResult resources = client.listResources(); -ReadResourceResult resource = client.readResource( - new ReadResourceRequest("resource://uri") -); - -// List and use prompts -ListPromptsResult prompts = client.listPrompts(); -GetPromptResult prompt = client.getPrompt( - new GetPromptRequest("greeting", Map.of("name", "Spring")) -); - -// Add/remove roots -client.addRoot(new Root("file:///path", "description")); -client.removeRoot("file:///path"); - -// Close client -client.closeGracefully(); ----- - -Async API:: -+ -[source,java] ----- -// Create an async client with custom configuration -McpAsyncClient client = McpClient.async(transport) - .requestTimeout(Duration.ofSeconds(10)) - .capabilities(ClientCapabilities.builder() - .roots(true) // Enable roots capability - .sampling() // Enable sampling capability - .build()) - .sampling(request -> Mono.just(new CreateMessageResult(response))) - .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> { - logger.info("Tools updated: {}", tools); - })) - .resourcesChangeConsumer(resources -> Mono.fromRunnable(() -> { - logger.info("Resources updated: {}", resources); - })) - .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> { - logger.info("Prompts updated: {}", prompts); - })) - .build(); - -// Initialize connection -client.initialize() - .flatMap(initResult -> { - // List available tools - return client.listTools(); - }) - .flatMap(tools -> { - // Call a tool - return client.callTool(new CallToolRequest( - "calculator", - Map.of("operation", "add", "a", 2, "b", 3) - )); - }) - .flatMap(result -> { - // List and read resources - return client.listResources() - .flatMap(resources -> - client.readResource(new ReadResourceRequest("resource://uri")) - ); - }) - .flatMap(resource -> { - // List and use prompts - return client.listPrompts() - .flatMap(prompts -> - client.getPrompt(new GetPromptRequest( - "greeting", - Map.of("name", "Spring") - )) - ); - }) - .flatMap(prompt -> { - // Add/remove roots - return client.addRoot(new Root("file:///path", "description")) - .then(client.removeRoot("file:///path")); - }) - .doFinally(signalType -> { - // Close client - client.closeGracefully().subscribe(); - }) - .subscribe(); ----- -====== - -=== Client Transport - -The transport layer handles the communication between MCP clients and servers, providing different implementations for various use cases. The client transport manages message serialization, connection establishment, and protocol-specific communication patterns. - -[tabs] -====== -STDIO:: -+ -Creates transport for in-process based communication -+ -[source,java] ----- -ServerParameters params = ServerParameters.builder("npx") - .args("-y", "@modelcontextprotocol/server-everything", "dir") - .build(); -McpTransport transport = new StdioClientTransport(params); ----- -+ -SSE (HttpClient):: -+ -Creates a framework agnostic (pure Java API) SSE client transport. -Included in the core `mcp` module. -+ -[source,java] ----- -McpTransport transport = new HttpClientSseClientTransport("http://your-mcp-server"); ----- -+ -SSE (WebFlux):: -+ -Creates WebFlux-based SSE client transport. -Requires the `mcp-webflux-sse-transport` dependency. -+ -[source,java] ----- -WebClient.Builder webClientBuilder = WebClient.builder() - .baseUrl("http://your-mcp-server"); -McpTransport transport = new WebFluxSseClientTransport(webClientBuilder); ----- -====== - -=== Client Capabilities - -The client can be configured with various capabilities: - -[source,java] ----- -var capabilities = ClientCapabilities.builder() - .roots(true) // Enable filesystem roots support with list changes notifications - .sampling() // Enable LLM sampling support - .build(); ----- - -==== Roots Support - -Roots define the boundaries of where servers can operate within the filesystem: - -[source,java] ----- -// Add a root dynamically -client.addRoot(new Root("file:///path", "description")); - -// Remove a root -client.removeRoot("file:///path"); - -// Notify server of roots changes -client.rootsListChangedNotification(); ----- - -The roots capability allows servers to: - -* Request the list of accessible filesystem roots -* Receive notifications when the root list changes -* Understand which directories and files they have access to - -==== Sampling Support - -Sampling enables servers to request LLM interactions ("completions" or "generations") through the client: - -[source,java] ----- -// Configure sampling handler -Function samplingHandler = request -> { - // Sampling implementation that interfaces with LLM - return new CreateMessageResult(response); -}; - -// Create client with sampling support -var client = McpClient.using(transport) - .capabilities(ClientCapabilities.builder() - .sampling() - .build()) - .sampling(samplingHandler) - .build(); ----- - -This capability allows: - -* Servers to leverage AI capabilities without requiring API keys -* Clients to maintain control over model access and permissions -* Support for both text and image-based interactions -* Optional inclusion of MCP server context in prompts - -[[mcp-server]] -== MCP Server - -The MCP Server is a foundational component in the Model Context Protocol (MCP) architecture that provides tools, resources, and capabilities to clients. It implements the server-side of the protocol, responsible for: - -* Exposing tools that clients can discover and execute -* Managing resources with URI-based access patterns -* Providing prompt templates and handling prompt requests -* Supporting capability negotiation with clients -* Implementing server-side protocol operations -* Managing concurrent client connections -* Providing structured logging and notifications - -The server supports both synchronous and asynchronous APIs, allowing for flexible integration in different application contexts. It can expose various capabilities such as file system operations, AI model interactions, and custom tool implementations. - -[tabs] -====== -Sync API:: -+ -[source,java] ----- -// Create a server with custom configuration -McpSyncServer syncServer = McpServer.sync(transport) - .serverInfo("my-server", "1.0.0") - .capabilities(ServerCapabilities.builder()...build()) - .tools(new McpServerFeatures.SyncToolRegistration(calculatorTool, calculatorHandler)) - .resources(new McpServerFeatures.SyncResourceRegistration(resource, resourceHandler)) - .prompts(new McpServerFeatures.SyncPromptRegistration(prompt, promptHandler)) - .build(); - -// Add a tool handler at runtime -syncServer.addTool(new CalculatorTool()); - -// Remove a tool handler at runtime -syncServer.removeTool("calculator"); - -// Add a resource at runtime -syncServer.addResource(resourceRegistration); - -// Remove a resource at runtime -syncServer.removeResource(resourceUri); - -// Add a prompt at runtime -syncServer.addPrompt(promptRegistration); - -// Remove a prompt at runtime -syncServer.removePrompt(promptName); - -// Graceful shutdown -syncServer.closeGracefully(); ----- - -Async API:: -+ -[source,java] ----- -// Create an async server with custom configuration -McpAsyncServer asyncServer = McpServer.async(transport) - .serverInfo("my-server", "1.0.0") - .capabilities(ServerCapabilities.builder()...build()) - .tools(new McpServerFeatures.AsyncToolRegistration(calculatorTool, args -> Mono.just(calculatorHandler.apply(args)))) - .resources(new McpServerFeatures.AsyncResourceRegistration(resource, req -> Mono.just(resourceHandler.apply(req)))) - .prompts(new McpServerFeatures.AsyncPromptRegistration(prompt, req -> Mono.just(promptHandler.apply(req)))) - .build(); - -// Add a tool handler at runtime -asyncServer.addTool(new CalculatorTool()) - .doOnSuccess(v -> logger.info("Tool added")) - .subscribe(); - -// Remove a tool handler at runtime -asyncServer.removeTool("calculator") - .doOnSuccess(v -> logger.info("Tool removed")) - .subscribe(); - -// Add a resource at runtime -asyncServer.addResource(resourceRegistration) - .doOnSuccess(v -> logger.info("Resource added")) - .subscribe(); - -// Remove a resource at runtime -asyncServer.removeResource(resourceUri) - .doOnSuccess(v -> logger.info("Resource removed")) - .subscribe(); - -// Add a prompt at runtime -asyncServer.addPrompt(promptRegistration) - .doOnSuccess(v -> logger.info("Prompt added")) - .subscribe(); - -// Remove a prompt at runtime -asyncServer.removePrompt(promptName) - .doOnSuccess(v -> logger.info("Prompt removed")) - .subscribe(); - -// Notify clients of changes -asyncServer.notifyToolsListChanged().subscribe(); -asyncServer.notifyResourcesListChanged().subscribe(); -asyncServer.notifyPromptsListChanged().subscribe(); - -// Graceful shutdown -asyncServer.closeGracefully().subscribe(); ----- -====== - -=== Server Transport - -The server transport layer implements the server-side communication protocols, enabling reliable message exchange with MCP clients. It provides implementations for different communication patterns while ensuring proper message handling, connection management, and protocol compliance. - -[tabs] -====== -STDIO:: -+ -Create in-process based transport -+ -[source,java] ----- -StdioServerTransport transport = new StdioServerTransport(new ObjectMapper()); ----- -+ -Provides bidirectional JSON-RPC message handling over standard input/output streams with non-blocking message processing, serialization/deserialization, and graceful shutdown support. - -SSE (WebFlux):: -+ -Creates WebFlux-based SSE server transport. -Requires the `mcp-webflux-sse-transport` dependency. -+ -[source,java] ----- -@Configuration -class McpConfig { - @Bean - WebFluxSseServerTransport webFluxSseServerTransport(ObjectMapper mapper) { - return new WebFluxSseServerTransport(mapper, "/mcp/message"); - } - - @Bean - RouterFunction mcpRouterFunction(WebFluxSseServerTransport transport) { - return transport.getRouterFunction(); - } -} ----- -+ -Implements the MCP HTTP with SSE transport specification, providing: -+ -* Reactive HTTP streaming with WebFlux -* Concurrent client connections through SSE endpoints -* Message routing and session management -* Graceful shutdown capabilities - -SSE (WebMvc):: -+ -Creates WebMvc-based SSE server transport. -Requires the `mcp-webmvc-sse-transport` dependency. -+ -[source,java] ----- -@Configuration -@EnableWebMvc -class McpConfig { - @Bean - WebMvcSseServerTransport webMvcSseServerTransport(ObjectMapper mapper) { - return new WebMvcSseServerTransport(mapper, "/mcp/message"); - } - - @Bean - RouterFunction mcpRouterFunction(WebMvcSseServerTransport transport) { - return transport.getRouterFunction(); - } -} ----- -+ -Implements the MCP HTTP with SSE transport specification, providing: -+ -* Servlet-based HTTP streaming with Spring MVC -* Concurrent client connections through SSE endpoints -* Message routing and session management -* Graceful shutdown capabilities - -SSE (Servlet):: -+ -Creates a Servlet-based SSE server transport. -Included in the core `mcp` module. -The `HttpServletSseServerTransport` can be used with any Servlet container. -To using it with a Spring Web application, you can register it as a Servlet bean: -+ -[source,java] ----- -@Configuration -@EnableWebMvc -public class McpServerConfig implements WebMvcConfigurer { - - @Bean - public HttpServletSseServerTransport servletSseServerTransport() { - return new HttpServletSseServerTransport(new ObjectMapper(), "/mcp/message"); - } - - @Bean - public ServletRegistrationBean customServletBean(HttpServletSseServerTransport servlet) { - return new ServletRegistrationBean(servlet); - } -} ----- -+ -(You can implment non Spring web container as well link:https://github.com/spring-projects-experimental/spring-ai-mcp/blob/main/mcp/src/test/java/org/springframework/ai/mcp/server/transport/HttpServletSseServerTransportIntegrationTests.java[HttpServletSseServerTransportIntegrationTests]) -+ -Implements the MCP HTTP with SSE transport specification using the traditional Servlet API, providing: -+ -* Asynchronous message handling using Servlet 6.0 async support -* Session management for multiple client connections -* Two types of endpoints: -** SSE endpoint (/sse) for server-to-client events -** Message endpoint (configurable) for client-to-server requests -* Error handling and response formatting -* Graceful shutdown support - -====== - -=== Server Capabilities - -The server can be configured with various capabilities: - -[source,java] ----- -var capabilities = ServerCapabilities.builder() - .resources(false, true) // Resource support with list changes notifications - .tools(true) // Tool support with list changes notifications - .prompts(true) // Prompt support with list changes notifications - .logging() // Enable logging support (enabled by default with loging level INFO) - .build(); ----- - -==== Logging Support - -The server provides structured logging capabilities that allow sending log messages to clients with different severity levels: - -[source,java] ----- -// Send a log message to clients -server.loggingNotification(LoggingMessageNotification.builder() - .level(LoggingLevel.INFO) - .logger("custom-logger") - .data("Custom log message") - .build()); ----- - -Clients can control the minimum logging level they receive through the `mcpClient.setLoggingLevel(level)` request. Messages below the set level will be filtered out. -Supported logging levels (in order of increasing severity): DEBUG (0), INFO (1), NOTICE (2), WARNING (3), ERROR (4), CRITICAL (5), ALERT (6), EMERGENCY (7) - - -==== Tool Registration - -[source,java] ----- -// Sync tool registration -var syncToolRegistration = new McpServerFeatures.SyncToolRegistration( - new Tool("calculator", "Basic calculator", Map.of( - "operation", "string", - "a", "number", - "b", "number" - )), - arguments -> { - // Tool implementation - return new CallToolResult(result, false); - } -); - -// Async tool registration -var asyncToolRegistration = new McpServerFeatures.AsyncToolRegistration( - new Tool("calculator", "Basic calculator", Map.of( - "operation", "string", - "a", "number", - "b", "number" - )), - arguments -> { - // Tool implementation - return Mono.just(new CallToolResult(result, false)); - } -); ----- - -==== Resource Registration - -[source,java] ----- -// Sync resource registration -var syncResourceRegistration = new McpServerFeatures.SyncResourceRegistration( - new Resource("custom://resource", "name", "description", "mime-type", null), - request -> { - // Resource read implementation - return new ReadResourceResult(contents); - } -); - -// Async resource registration -var asyncResourceRegistration = new McpServerFeatures.AsyncResourceRegistration( - new Resource("custom://resource", "name", "description", "mime-type", null), - request -> { - // Resource read implementation - return Mono.just(new ReadResourceResult(contents)); - } -); ----- - -==== Prompt Registration - -[source,java] ----- -// Sync prompt registration -var syncPromptRegistration = new McpServerFeatures.SyncPromptRegistration( - new Prompt("greeting", "description", List.of( - new PromptArgument("name", "description", true) - )), - request -> { - // Prompt implementation - return new GetPromptResult(description, messages); - } -); - -// Async prompt registration -var asyncPromptRegistration = new McpServerFeatures.AsyncPromptRegistration( - new Prompt("greeting", "description", List.of( - new PromptArgument("name", "description", true) - )), - request -> { - // Prompt implementation - return Mono.just(new GetPromptResult(description, messages)); - } -); ----- - -== Error Handling - -The SDK provides comprehensive error handling through the McpError class, covering protocol compatibility, transport communication, JSON-RPC messaging, tool execution, resource management, prompt handling, timeouts, and connection issues. This unified error handling approach ensures consistent and reliable error management across both synchronous and asynchronous operations. diff --git a/mcp-docs/src/main/antora/modules/ROOT/pages/overview.adoc b/mcp-docs/src/main/antora/modules/ROOT/pages/overview.adoc deleted file mode 100644 index 8e6dbac5b..000000000 --- a/mcp-docs/src/main/antora/modules/ROOT/pages/overview.adoc +++ /dev/null @@ -1,142 +0,0 @@ -= Java & Spring MCP - -Java SDK and Spring Framework integration for the link:https://modelcontextprotocol.org/docs/concepts/architecture[Model Context Protocol], enabling standardized interaction with AI models and tools through both synchronous and asynchronous communication. - -image::spring-ai-mcp-clinet-architecture.jpg[Spring AI MCP Client Architecture,600,float="right",align="left"] -image::spring-ai-mcp-server-architecture.jpg[Spring AI MCP Server Architecture,600,align="right"] - -== Core Components - -=== xref:mcp.adoc[MCP Java SDK] -Core implementation of the Model Context Protocol specification, providing: - -* Synchronous and asynchronous xref:mcp.adoc#mcp-client[Client] and xref:mcp.adoc#mcp-server[Server] implementations -* Tool discovery and execution -* Resource management with URI templates -* Prompt handling and management -* Structured logging -* Request and Notification handling - -=== MCP Transports - -* *Core Transports* -** Stdio-based (`StdioClientTransport`, `StdioServerTransport`) for process-based communication -** Java HttpClient-based SSE client (`HttpClientSseClientTransport`) for HTTP SSE Client-side streaming -** Servlet-based SSE server (`HttpServletSseServerTransport`) for HTTP SSE Server streaming using traditional Servlet API - -* *Optional SSE Transports* -** link:https://github.com/spring-projects-experimental/spring-ai-mcp/tree/main/mcp-transport/mcp-webflux-sse-transport[WebFlux SSE Transport] - Reactive HTTP streaming with Spring WebFlux (Client & Server) -** link:https://github.com/spring-projects-experimental/spring-ai-mcp/tree/main/mcp-transport/mcp-webmvc-sse-transport[WebMvc SSE Transport] - Spring MVC based SSE transport (Server only). -You can use the core `HttpClientSseClientTransport` transport as a SSE client. - -=== xref:spring-mcp.adoc[Spring AI MCP] -Spring integration features: - -* Spring AI tool/function calling system integration -* Bidirectional conversion between Spring AI function callbacks and MCP tools -* JSON schema generation for tool input validation -* Automatic type conversion and error handling -* Spring-friendly MCP client abstractions -* Auto-configurations (WIP) - -== Getting Started - -[tabs] -====== -Maven:: -+ -[source,xml] ----- - - - org.springframework.experimental - mcp - - - - - org.springframework.experimental - mcp-webflux-sse-transport - - - - - org.springframework.experimental - mcp-webmvc-sse-transport - - - - - org.springframework.experimental - spring-ai-mcp - ----- -+ -Add Spring milestone repository: -+ -[source,xml] ----- - - - spring-milestones - Spring Milestones - https://repo.spring.io/milestone - - false - - - ----- - -Gradle:: -+ -[source,groovy] ----- -dependencies { - implementation 'org.springframework.experimental:mcp' // Core - implementation 'org.springframework.experimental:mcp-webflux-sse-transport' // Optional - implementation 'org.springframework.experimental:mcp-webmvc-sse-transport' // Optional - implementation 'org.springframework.experimental:spring-ai-mcp' // Optional -} - -repositories { - maven { url 'https://repo.spring.io/milestone' } -} ----- -====== - -Reffer to the xref:dependency-management.adoc[Dependency Management] page for more information. - -== Examples - -* link:https://github.com/spring-projects/spring-ai-examples/tree/main/model-context-protocol/sqlite/simple[SQLite Simple] - Basic LLM-database integration -* link:https://github.com/spring-projects/spring-ai-examples/tree/main/model-context-protocol/sqlite/chatbot[SQLite Chatbot] - Interactive database chatbot -* link:https://github.com/spring-projects/spring-ai-examples/tree/main/model-context-protocol/filesystem[Filesystem] - LLM interaction with local files -* link:https://github.com/spring-projects/spring-ai-examples/tree/main/model-context-protocol/brave[Brave] - Natural language Brave Search integration -* link:https://github.com/habuma/spring-ai-examples/tree/main/spring-ai-mcp[Theme Park API] - MCP server/client with Theme Park API tools -* link:https://github.com/spring-projects/spring-ai-examples/tree/main/model-context-protocol/mcp-webflux-server[WebFlux SSE] - WebFlux server/client implementation -* link:https://github.com/spring-projects/spring-ai-examples/tree/main/model-context-protocol/mcp-webmvc-server[WebMvc SSE] - WebMvc server with HttpClient implementation -* link:https://github.com/spring-projects/spring-ai-examples/tree/main/model-context-protocol/mcp-servlet-server[Servlet SSE] - SSE Servlet server with HttpClient implementation - -== Documentation - -* xref:mcp.adoc[Java MCP SDK Documentation] -* xref:spring-mcp.adoc[Spring Integration Documentation] - -== Development - -Build from source: -[source,bash] ----- -mvn clean install ----- - -Run tests: -[source,bash] ----- -mvn test ----- - -== License - -This project is licensed under the link:LICENSE[Apache License 2.0]. diff --git a/mcp-docs/src/main/antora/modules/ROOT/pages/spring-mcp.adoc b/mcp-docs/src/main/antora/modules/ROOT/pages/spring-mcp.adoc deleted file mode 100644 index fd28fa95c..000000000 --- a/mcp-docs/src/main/antora/modules/ROOT/pages/spring-mcp.adoc +++ /dev/null @@ -1,162 +0,0 @@ -= Spring AI MCP Spring -:page-title: Spring AI MCP Spring -:doctype: book -:icons: font -:source-highlighter: highlight.js -:toc: left - -Spring Integration module for Model Control Protocol (MCP) that provides Spring-specific functionality for working with MCP clients. - -== Overview - -The `spring-ai-mcp` module is part of the https://github.com/spring-projects-experimental/spring-ai-mcp[Spring AI MCP] project. It provides Spring Framework integration for the Model Control Protocol (MCP), enabling seamless integration of MCP functionality within Spring applications. - -== Features - -* Spring integration for MCP clients -* Bidirectional conversion between Spring AI function callbacks and MCP tools -* JSON schema generation for tool input validation -* Automatic type conversion and error handling -* Integration with Spring AI's function calling capabilities - -== Main Components - -=== McpFunctionCallback - -The `McpFunctionCallback` class implements Spring AI's `FunctionCallback` interface and provides integration between Spring AI's function calling system and MCP tools. Key features include: - -* Automatic conversion between JSON and Java objects for tool arguments -* Synchronous tool execution support -* Error handling and result formatting -* Integration with Spring AI's function calling system - -Example usage: - -[source,java] ----- -// Create an MCP client -McpSyncClient mcpClient = McpClient.using(transport) - .sync(); - -// Create a function callback for an MCP tool -Tool calculatorTool = new Tool("calculator", "Basic calculator", - Map.of("operation", "string", "a", "number", "b", "number")); -FunctionCallback callback = new McpFunctionCallback(mcpClient, calculatorTool); - -// Use the callback with Spring AI -String result = callback.call("{\"operation\": \"add\", \"a\": 2, \"b\": 3}"); ----- - -=== ToolHelper - -The `ToolHelper` utility class facilitates the integration between Spring AI's function callbacks and MCP's tool system. It provides methods for: - -* Converting Spring AI's `FunctionCallback` instances to MCP tool registrations -* Generating JSON schemas for tool input validation -* Handling error cases and result formatting - -Example usage: - -[source,java] ----- -// Convert Spring AI function callbacks to MCP tool registrations -List callbacks = List.of( - new CalculatorFunction(), - new WeatherFunction() -); -List tools = ToolHelper.toToolRegistration(callbacks); - -// Generate JSON schema for tool inputs -Map> inputTypes = Map.of( - "calculator", CalculatorInput.class, - "weather", WeatherInput.class -); -String schema = ToolHelper.generateJsonSchema(inputTypes); ----- - -==== Converting Function Callbacks to Tools - -The `ToolHelper` provides several methods to convert Spring AI function callbacks to MCP tools: - -[source,java] ----- -// Convert a single function callback -ToolRegistration tool = ToolHelper.toToolRegistration(myCallback); - -// Convert multiple callbacks -List tools = ToolHelper.toToolRegistration(callback1, callback2); - -// Convert a list of callbacks -List tools = ToolHelper.toToolRegistration(callbackList); ----- - -==== JSON Schema Generation - -The `ToolHelper` can generate JSON schemas for tool input validation: - -[source,java] ----- -// Using default ObjectMapper -String schema = ToolHelper.generateJsonSchema(inputTypes); - -// Using custom ObjectMapper -ObjectMapper mapper = new ObjectMapper(); -String schema = ToolHelper.generateJsonSchema(inputTypes, mapper); ----- - -The generated schema follows the JSON Schema Draft 2020-12 specification and: -* Validates the structure of tool inputs -* Excludes ToolContext class from schema generation -* Uses Jackson's JsonSchemaGenerator for accurate type representation - -== Usage - -To use this module, add the following dependency to your Maven project: - -[source,xml] ----- - - org.springframework.experimental - spring-ai-mcp - ----- - -Reffer to the xref:dependency-management.adoc[Dependency Management] page for more information. - -=== Example: Creating an MCP Tool Server with Spring AI Functions - -[source,java] ----- -@Configuration -class McpConfig { - - @Bean - McpServer mcpServer(List callbacks) { - // Convert Spring AI callbacks to MCP tools - List tools = ToolHelper.toToolRegistration(callbacks); - - return McpServer.using(transport) - .info("spring-ai-server", "1.0.0") - .tools(tools) - .sync(); - } - - @Bean - FunctionCallback calculatorFunction() { - return FunctionCallback.builder() - .name("calculator") - .description("Basic calculator") - .function(input -> { - // Function implementation - return result; - }) - .build(); - } -} ----- - -This configuration: -1. Creates Spring AI function callbacks -2. Converts them to MCP tools using ToolHelper -3. Registers the tools with an MCP server -4. Makes the tools available for discovery and execution by MCP clients diff --git a/mcp-docs/src/main/antora/resources/antora-resources/antora.yml b/mcp-docs/src/main/antora/resources/antora-resources/antora.yml deleted file mode 100644 index 99706e09e..000000000 --- a/mcp-docs/src/main/antora/resources/antora-resources/antora.yml +++ /dev/null @@ -1,2 +0,0 @@ -version: ${antora-component.version} -prerelease: ${antora-component.prerelease} \ No newline at end of file diff --git a/mcp-docs/src/main/javadoc/overview.html b/mcp-docs/src/main/javadoc/overview.html deleted file mode 100644 index 3a30a697b..000000000 --- a/mcp-docs/src/main/javadoc/overview.html +++ /dev/null @@ -1,33 +0,0 @@ - - - - -

- This document is the API specification for Spring AI -

-
-

- For further API reference and developer documentation, see the - - Spring AI MCP reference documentation. - That documentation contains more detailed, developer-targeted - descriptions, with conceptual overviews, definitions of terms, - and working code examples. -

-
- - \ No newline at end of file diff --git a/mcp-transport/mcp-webflux-sse-transport/README.md b/mcp-spring/mcp-spring-webflux/README.md similarity index 92% rename from mcp-transport/mcp-webflux-sse-transport/README.md rename to mcp-spring/mcp-spring-webflux/README.md index 57e4d1142..e701e41e6 100644 --- a/mcp-transport/mcp-webflux-sse-transport/README.md +++ b/mcp-spring/mcp-spring-webflux/README.md @@ -2,8 +2,8 @@ ```xml - org.springframework.experimental - mcp-webflux-sse-transport + io.modelcontextprotocol.sdk + mcp-spring-webflux ``` diff --git a/mcp-transport/mcp-webflux-sse-transport/pom.xml b/mcp-spring/mcp-spring-webflux/pom.xml similarity index 82% rename from mcp-transport/mcp-webflux-sse-transport/pom.xml rename to mcp-spring/mcp-spring-webflux/pom.xml index df23747e7..c1425c183 100644 --- a/mcp-transport/mcp-webflux-sse-transport/pom.xml +++ b/mcp-spring/mcp-spring-webflux/pom.xml @@ -4,34 +4,34 @@ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd"> 4.0.0 - org.springframework.experimental + io.modelcontextprotocol.sdk mcp-parent - 0.6.0 + 0.8.0 ../../pom.xml - mcp-webflux-sse-transport + mcp-spring-webflux jar WebFlux implementation of the Java MCP SSE transport - https://github.com/spring-projects-experimental/spring-ai-mcp + https://github.com/modelcontextprotocol/java-sdk - https://github.com/spring-projects-experimental/spring-ai-mcp - git://github.com/spring-projects-experimental/spring-ai-mcp.git - git@github.com:spring-projects-experimental/spring-ai-mcp.git + https://github.com/modelcontextprotocol/java-sdk + git://github.com/modelcontextprotocol/java-sdk.git + git@github.com/modelcontextprotocol/java-sdk.git - org.springframework.experimental + io.modelcontextprotocol.sdk mcp - 0.6.0 + 0.8.0 - org.springframework.experimental + io.modelcontextprotocol.sdk mcp-test - 0.6.0 + 0.8.0 test diff --git a/mcp-transport/mcp-webflux-sse-transport/src/main/java/org/springframework/ai/mcp/client/transport/WebFluxSseClientTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java similarity index 90% rename from mcp-transport/mcp-webflux-sse-transport/src/main/java/org/springframework/ai/mcp/client/transport/WebFluxSseClientTransport.java rename to mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java index a4605c657..b0dfa89c0 100644 --- a/mcp-transport/mcp-webflux-sse-transport/src/main/java/org/springframework/ai/mcp/client/transport/WebFluxSseClientTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java @@ -1,19 +1,7 @@ /* * Copyright 2024 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.client.transport; +package io.modelcontextprotocol.client.transport; import java.io.IOException; import java.util.function.BiConsumer; @@ -21,6 +9,11 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.Disposable; @@ -32,11 +25,6 @@ import reactor.util.retry.Retry; import reactor.util.retry.Retry.RetrySignal; -import org.springframework.ai.mcp.spec.ClientMcpTransport; -import org.springframework.ai.mcp.spec.McpError; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.JSONRPCMessage; -import org.springframework.ai.mcp.util.Assert; import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.MediaType; import org.springframework.http.codec.ServerSentEvent; @@ -44,7 +32,7 @@ /** * Server-Sent Events (SSE) implementation of the - * {@link org.springframework.ai.mcp.spec.McpTransport} that follows the MCP HTTP with SSE + * {@link io.modelcontextprotocol.spec.McpTransport} that follows the MCP HTTP with SSE * transport specification. * *

@@ -70,7 +58,7 @@ * "https://spec.modelcontextprotocol.io/specification/basic/transports/#http-with-sse">MCP * HTTP with SSE Transport Specification */ -public class WebFluxSseClientTransport implements ClientMcpTransport { +public class WebFluxSseClientTransport implements McpClientTransport { private static final Logger logger = LoggerFactory.getLogger(WebFluxSseClientTransport.class); @@ -315,7 +303,7 @@ public Mono closeGracefully() { // @formatter:off } // @formatter:on /** - * Unmarshals data from a generic Object into the specified type using the configured + * Unmarshalls data from a generic Object into the specified type using the configured * ObjectMapper. * *

@@ -325,7 +313,7 @@ public Mono closeGracefully() { // @formatter:off * @param the target type to convert the data into * @param data the source object to convert * @param typeRef the TypeReference describing the target type - * @return the unmarshaled object of type T + * @return the unmarshalled object of type T * @throws IllegalArgumentException if the conversion cannot be performed */ @Override diff --git a/mcp-transport/mcp-webflux-sse-transport/src/main/java/org/springframework/ai/mcp/server/transport/WebFluxSseServerTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java similarity index 94% rename from mcp-transport/mcp-webflux-sse-transport/src/main/java/org/springframework/ai/mcp/server/transport/WebFluxSseServerTransport.java rename to mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java index 6126eba58..fb0b581e0 100644 --- a/mcp-transport/mcp-webflux-sse-transport/src/main/java/org/springframework/ai/mcp/server/transport/WebFluxSseServerTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java @@ -1,4 +1,4 @@ -package org.springframework.ai.mcp.server.transport; +package io.modelcontextprotocol.server.transport; import java.io.IOException; import java.time.Duration; @@ -9,16 +9,16 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; -import org.springframework.ai.mcp.spec.McpError; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.ServerMcpTransport; -import org.springframework.ai.mcp.util.Assert; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.codec.ServerSentEvent; @@ -60,7 +60,10 @@ * @author Alexandros Pappas * @see ServerMcpTransport * @see ServerSentEvent + * @deprecated This class will be removed in 0.9.0. Use + * {@link WebFluxSseServerTransportProvider}. */ +@Deprecated public class WebFluxSseServerTransport implements ServerMcpTransport { private static final Logger logger = LoggerFactory.getLogger(WebFluxSseServerTransport.class); @@ -182,16 +185,16 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { try {// @formatter:off String jsonText = objectMapper.writeValueAsString(message); ServerSentEvent event = ServerSentEvent.builder() - .event(MESSAGE_EVENT_TYPE) - .data(jsonText) - .build(); + .event(MESSAGE_EVENT_TYPE) + .data(jsonText) + .build(); logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); List failedSessions = sessions.values().stream() - .filter(session -> session.messageSink.tryEmitNext(event).isFailure()) - .map(session -> session.id) - .toList(); + .filter(session -> session.messageSink.tryEmitNext(event).isFailure()) + .map(session -> session.id) + .toList(); if (failedSessions.isEmpty()) { logger.debug("Successfully broadcast message to all sessions"); @@ -251,7 +254,7 @@ public Mono closeGracefully() { .then(Mono.fromRunnable(() -> sessions.remove(sessionId))); }).toList())) .timeout(Duration.ofSeconds(5)) - .doOnSuccess(v -> logger.info("Graceful shutdown completed")) + .doOnSuccess(v -> logger.debug("Graceful shutdown completed")) .doOnError(e -> logger.error("Error during graceful shutdown: {}", e.getMessage())); } @@ -407,4 +410,4 @@ void close() { } -} +} \ No newline at end of file diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java new file mode 100644 index 000000000..cf3eeae03 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -0,0 +1,351 @@ +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Exceptions; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; + +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.reactive.function.server.ServerResponse; + +/** + * Server-side implementation of the MCP (Model Context Protocol) HTTP transport using + * Server-Sent Events (SSE). This implementation provides a bidirectional communication + * channel between MCP clients and servers using HTTP POST for client-to-server messages + * and SSE for server-to-client messages. + * + *

+ * Key features: + *

    + *
  • Implements the {@link McpServerTransportProvider} interface that allows managing + * {@link McpServerSession} instances and enabling their communication with the + * {@link McpServerTransport} abstraction.
  • + *
  • Uses WebFlux for non-blocking request handling and SSE support
  • + *
  • Maintains client sessions for reliable message delivery
  • + *
  • Supports graceful shutdown with session cleanup
  • + *
  • Thread-safe message broadcasting to multiple clients
  • + *
+ * + *

+ * The transport sets up two main endpoints: + *

    + *
  • SSE endpoint (/sse) - For establishing SSE connections with clients
  • + *
  • Message endpoint (configurable) - For receiving JSON-RPC messages from clients
  • + *
+ * + *

+ * This implementation is thread-safe and can handle multiple concurrent client + * connections. It uses {@link ConcurrentHashMap} for session management and Project + * Reactor's non-blocking APIs for message processing and delivery. + * + * @author Christian Tzolov + * @author Alexandros Pappas + * @author Dariusz Jędrzejczyk + * @see McpServerTransport + * @see ServerSentEvent + */ +public class WebFluxSseServerTransportProvider implements McpServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(WebFluxSseServerTransportProvider.class); + + /** + * Event type for JSON-RPC messages sent through the SSE connection. + */ + public static final String MESSAGE_EVENT_TYPE = "message"; + + /** + * Event type for sending the message endpoint URI to clients. + */ + public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + /** + * Default SSE endpoint path as specified by the MCP transport specification. + */ + public static final String DEFAULT_SSE_ENDPOINT = "/sse"; + + private final ObjectMapper objectMapper; + + private final String messageEndpoint; + + private final String sseEndpoint; + + private final RouterFunction routerFunction; + + private McpServerSession.Factory sessionFactory; + + /** + * Map of active client sessions, keyed by session ID. + */ + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + /** + * Flag indicating if the transport is shutting down. + */ + private volatile boolean isClosing = false; + + /** + * Constructs a new WebFlux SSE server transport provider instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of MCP messages. Must not be null. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages. This endpoint will be communicated to clients during SSE connection + * setup. Must not be null. + * @throws IllegalArgumentException if either parameter is null + */ + public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); + + this.objectMapper = objectMapper; + this.messageEndpoint = messageEndpoint; + this.sseEndpoint = sseEndpoint; + this.routerFunction = RouterFunctions.route() + .GET(this.sseEndpoint, this::handleSseConnection) + .POST(this.messageEndpoint, this::handleMessage) + .build(); + } + + /** + * Constructs a new WebFlux SSE server transport provider instance with the default + * SSE endpoint. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of MCP messages. Must not be null. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages. This endpoint will be communicated to clients during SSE connection + * setup. Must not be null. + * @throws IllegalArgumentException if either parameter is null + */ + public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { + this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); + } + + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + /** + * Broadcasts a JSON-RPC message to all connected clients through their SSE + * connections. The message is serialized to JSON and sent as a server-sent event to + * each active session. + * + *

+ * The method: + *

    + *
  • Serializes the message to JSON
  • + *
  • Creates a server-sent event with the message data
  • + *
  • Attempts to send the event to all active sessions
  • + *
  • Tracks and reports any delivery failures
  • + *
+ * @param method The JSON-RPC method to send to clients + * @param params The method parameters to send to clients + * @return A Mono that completes when the message has been sent to all sessions, or + * errors if any session fails to receive the message + */ + @Override + public Mono notifyClients(String method, Map params) { + if (sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + + return Flux.fromStream(sessions.values().stream()) + .flatMap(session -> session.sendNotification(method, params) + .doOnError(e -> logger.error("Failed to " + "send message to session " + "{}: {}", session.getId(), + e.getMessage())) + .onErrorComplete()) + .then(); + } + + // FIXME: This javadoc makes claims about using isClosing flag but it's not actually + // doing that. + /** + * Initiates a graceful shutdown of all the sessions. This method ensures all active + * sessions are properly closed and cleaned up. + * + *

+ * The shutdown process: + *

    + *
  • Marks the transport as closing to prevent new connections
  • + *
  • Closes each active session
  • + *
  • Removes closed sessions from the sessions map
  • + *
  • Times out after 5 seconds if shutdown takes too long
  • + *
+ * @return A Mono that completes when all sessions have been closed + */ + @Override + public Mono closeGracefully() { + return Flux.fromIterable(sessions.values()) + .doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size())) + .flatMap(McpServerSession::closeGracefully) + .then(); + } + + /** + * Returns the WebFlux router function that defines the transport's HTTP endpoints. + * This router function should be integrated into the application's web configuration. + * + *

+ * The router function defines two endpoints: + *

    + *
  • GET {sseEndpoint} - For establishing SSE connections
  • + *
  • POST {messageEndpoint} - For receiving client messages
  • + *
+ * @return The configured {@link RouterFunction} for handling HTTP requests + */ + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + /** + * Handles new SSE connection requests from clients. Creates a new session for each + * connection and sets up the SSE event stream. + * @param request The incoming server request + * @return A Mono which emits a response with the SSE event stream + */ + private Mono handleSseConnection(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + return ServerResponse.ok() + .contentType(MediaType.TEXT_EVENT_STREAM) + .body(Flux.>create(sink -> { + WebFluxMcpSessionTransport sessionTransport = new WebFluxMcpSessionTransport(sink); + + McpServerSession session = sessionFactory.create(sessionTransport); + String sessionId = session.getId(); + + logger.debug("Created new SSE connection for session: {}", sessionId); + sessions.put(sessionId, session); + + // Send initial endpoint event + logger.debug("Sending initial endpoint event to session: {}", sessionId); + sink.next(ServerSentEvent.builder() + .event(ENDPOINT_EVENT_TYPE) + .data(messageEndpoint + "?sessionId=" + sessionId) + .build()); + sink.onCancel(() -> { + logger.debug("Session {} cancelled", sessionId); + sessions.remove(sessionId); + }); + }), ServerSentEvent.class); + } + + /** + * Handles incoming JSON-RPC messages from clients. Deserializes the message and + * processes it through the configured message handler. + * + *

+ * The handler: + *

    + *
  • Deserializes the incoming JSON-RPC message
  • + *
  • Passes it through the message handler chain
  • + *
  • Returns appropriate HTTP responses based on processing results
  • + *
  • Handles various error conditions with appropriate error responses
  • + *
+ * @param request The incoming server request containing the JSON-RPC message + * @return A Mono emitting the response indicating the message processing result + */ + private Mono handleMessage(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + if (request.queryParam("sessionId").isEmpty()) { + return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing in message endpoint")); + } + + McpServerSession session = sessions.get(request.queryParam("sessionId").get()); + + return request.bodyToMono(String.class).flatMap(body -> { + try { + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); + return session.handle(message).flatMap(response -> ServerResponse.ok().build()).onErrorResume(error -> { + logger.error("Error processing message: {}", error.getMessage()); + // TODO: instead of signalling the error, just respond with 200 OK + // - the error is signalled on the SSE connection + // return ServerResponse.ok().build(); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .bodyValue(new McpError(error.getMessage())); + }); + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format")); + } + }); + } + + private class WebFluxMcpSessionTransport implements McpServerTransport { + + private final FluxSink> sink; + + public WebFluxMcpSessionTransport(FluxSink> sink) { + this.sink = sink; + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return Mono.fromSupplier(() -> { + try { + return objectMapper.writeValueAsString(message); + } + catch (IOException e) { + throw Exceptions.propagate(e); + } + }).doOnNext(jsonText -> { + ServerSentEvent event = ServerSentEvent.builder() + .event(MESSAGE_EVENT_TYPE) + .data(jsonText) + .build(); + sink.next(event); + }).doOnError(e -> { + // TODO log with sessionid + Throwable exception = Exceptions.unwrap(e); + sink.error(exception); + }).then(); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(sink::complete); + } + + @Override + public void close() { + sink.complete(); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java new file mode 100644 index 000000000..2d9d055f3 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -0,0 +1,503 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Mono; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; +import reactor.test.StepVerifier; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunctions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.mock; + +public class WebFluxSseIntegrationTests { + + private static final int PORT = 8182; + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private DisposableServer httpServer; + + private WebFluxSseServerTransportProvider mcpServerTransportProvider; + + ConcurrentHashMap clientBulders = new ConcurrentHashMap<>(); + + @BeforeEach + public void before() { + + this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransportProvider.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + + clientBulders.put("httpclient", McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT))); + clientBulders.put("webflux", + McpClient.sync(new WebFluxSseClientTransport(WebClient.builder().baseUrl("http://localhost:" + PORT)))); + + } + + @AfterEach + public void after() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + + // --------------------------------------- + // Sampling Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageWithoutSamplingCapabilities(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); + + return Mono.just(mock(CallToolResult.class)); + }); + + McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageSuccess(String clientType) throws InterruptedException { + + // Client + var clientBuilder = clientBulders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsSuccess(String clientType) { + var clientBuilder = clientBulders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Remove a root + mcpClient.removeRoot(roots.get(0).uri()); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); + }); + + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsWithoutCapability(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.listRoots(); // try to list roots + + return mock(CallToolResult.class); + }); + + var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { + }).tools(tool).build(); + + // Create client without roots capability + // No roots capability + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build(); + + assertThat(mcpClient.initialize()).isNotNull(); + + // Attempt to list roots should fail + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsNotifciationWithEmptyRootsList(String clientType) { + var clientBuilder = clientBulders.get(clientType); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(List.of()) // Empty roots list + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsWithMultipleHandlers(String clientType) { + var clientBuilder = clientBulders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef1 = new AtomicReference<>(); + AtomicReference> rootsRef2 = new AtomicReference<>(); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsServerCloseWithActiveSubscription(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Close server while subscription is active + mcpServer.close(); + + // Verify client can handle server closure gracefully + mcpClient.close(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolCallSuccess(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + var mcpClient = clientBuilder.build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolListChangeHandlingSuccess(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + rootsRef.set(toolsUpdate); + }).build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + mcpServer.notifyToolsListChanged(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); + }); + + // Remove a tool + mcpServer.removeTool("tool1"); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + + // Add a new tool + McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), (exchange, request) -> callResponse); + + mcpServer.addTool(tool2); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testInitialize(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); + + var mcpClient = clientBuilder.build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.close(); + mcpServer.close(); + } + +} diff --git a/mcp-transport/mcp-webflux-sse-transport/src/test/java/org/springframework/ai/mcp/client/WebFluxSseMcpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java similarity index 61% rename from mcp-transport/mcp-webflux-sse-transport/src/test/java/org/springframework/ai/mcp/client/WebFluxSseMcpAsyncClientTests.java rename to mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java index 761dd6ab6..2dd587d4f 100644 --- a/mcp-transport/mcp-webflux-sse-transport/src/test/java/org/springframework/ai/mcp/client/WebFluxSseMcpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java @@ -1,27 +1,17 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.client; +package io.modelcontextprotocol.client; +import java.time.Duration; + +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; -import org.springframework.ai.mcp.client.transport.WebFluxSseClientTransport; -import org.springframework.ai.mcp.spec.ClientMcpTransport; import org.springframework.web.reactive.function.client.WebClient; /** @@ -42,7 +32,7 @@ class WebFluxSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected ClientMcpTransport createMcpTransport() { + protected McpClientTransport createMcpTransport() { return new WebFluxSseClientTransport(WebClient.builder().baseUrl(host)); } @@ -58,4 +48,8 @@ public void onClose() { container.stop(); } + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(1); + } + } diff --git a/mcp-transport/mcp-webflux-sse-transport/src/test/java/org/springframework/ai/mcp/client/WebFluxSseMcpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java similarity index 61% rename from mcp-transport/mcp-webflux-sse-transport/src/test/java/org/springframework/ai/mcp/client/WebFluxSseMcpSyncClientTests.java rename to mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java index 41153afd2..72b390ddd 100644 --- a/mcp-transport/mcp-webflux-sse-transport/src/test/java/org/springframework/ai/mcp/client/WebFluxSseMcpSyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java @@ -1,27 +1,17 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.client; +package io.modelcontextprotocol.client; +import java.time.Duration; + +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; -import org.springframework.ai.mcp.client.transport.WebFluxSseClientTransport; -import org.springframework.ai.mcp.spec.ClientMcpTransport; import org.springframework.web.reactive.function.client.WebClient; /** @@ -42,7 +32,7 @@ class WebFluxSseMcpSyncClientTests extends AbstractMcpSyncClientTests { .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected ClientMcpTransport createMcpTransport() { + protected McpClientTransport createMcpTransport() { return new WebFluxSseClientTransport(WebClient.builder().baseUrl(host)); } @@ -58,4 +48,8 @@ protected void onClose() { container.stop(); } + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(1); + } + } diff --git a/mcp-transport/mcp-webflux-sse-transport/src/test/java/org/springframework/ai/mcp/client/transport/WebFluxSseClientTransportTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java similarity index 92% rename from mcp-transport/mcp-webflux-sse-transport/src/test/java/org/springframework/ai/mcp/client/transport/WebFluxSseClientTransportTests.java rename to mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java index a75f851d7..912e04f14 100644 --- a/mcp-transport/mcp-webflux-sse-transport/src/test/java/org/springframework/ai/mcp/client/transport/WebFluxSseClientTransportTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java @@ -1,20 +1,8 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.client.transport; +package io.modelcontextprotocol.client.transport; import java.time.Duration; import java.util.Map; @@ -22,6 +10,8 @@ import java.util.function.Function; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -33,8 +23,6 @@ import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.JSONRPCRequest; import org.springframework.http.codec.ServerSentEvent; import org.springframework.web.reactive.function.client.WebClient; diff --git a/mcp-transport/mcp-webflux-sse-transport/src/test/java/org/springframework/ai/mcp/server/WebFluxSseMcpAsyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerDeprecatedTests.java similarity index 62% rename from mcp-transport/mcp-webflux-sse-transport/src/test/java/org/springframework/ai/mcp/server/WebFluxSseMcpAsyncServerTests.java rename to mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerDeprecatedTests.java index b90c9a512..b460284ee 100644 --- a/mcp-transport/mcp-webflux-sse-transport/src/test/java/org/springframework/ai/mcp/server/WebFluxSseMcpAsyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerDeprecatedTests.java @@ -1,28 +1,16 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.server; +package io.modelcontextprotocol.server; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; import org.junit.jupiter.api.Timeout; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; -import org.springframework.ai.mcp.server.transport.WebFluxSseServerTransport; -import org.springframework.ai.mcp.spec.ServerMcpTransport; import org.springframework.http.server.reactive.HttpHandler; import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; import org.springframework.web.reactive.function.server.RouterFunctions; @@ -32,8 +20,9 @@ * * @author Christian Tzolov */ +@Deprecated @Timeout(15) // Giving extra time beyond the client timeout -class WebFluxSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { +class WebFluxSseMcpAsyncServerDeprecatedTests extends AbstractMcpAsyncServerDeprecatedTests { private static final int PORT = 8181; diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java new file mode 100644 index 000000000..5fa787ab6 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java @@ -0,0 +1,53 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.Timeout; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.server.RouterFunctions; + +/** + * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class WebFluxSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { + + private static final int PORT = 8181; + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private DisposableServer httpServer; + + @Override + protected McpServerTransportProvider createMcpTransportProvider() { + var transportProvider = new WebFluxSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + return transportProvider; + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-transport/mcp-webflux-sse-transport/src/test/java/org/springframework/ai/mcp/server/WebFluxSseMcpSyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerDeprecatecTests.java similarity index 62% rename from mcp-transport/mcp-webflux-sse-transport/src/test/java/org/springframework/ai/mcp/server/WebFluxSseMcpSyncServerTests.java rename to mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerDeprecatecTests.java index 048d6398b..be2bf6c7f 100644 --- a/mcp-transport/mcp-webflux-sse-transport/src/test/java/org/springframework/ai/mcp/server/WebFluxSseMcpSyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerDeprecatecTests.java @@ -1,28 +1,16 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.server; +package io.modelcontextprotocol.server; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; import org.junit.jupiter.api.Timeout; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; -import org.springframework.ai.mcp.server.transport.WebFluxSseServerTransport; -import org.springframework.ai.mcp.spec.ServerMcpTransport; import org.springframework.http.server.reactive.HttpHandler; import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; import org.springframework.web.reactive.function.server.RouterFunctions; @@ -32,8 +20,9 @@ * * @author Christian Tzolov */ +@Deprecated @Timeout(15) // Giving extra time beyond the client timeout -class WebFluxSseMcpSyncServerTests extends AbstractMcpSyncServerTests { +class WebFluxSseMcpSyncServerDeprecatecTests extends AbstractMcpSyncServerDeprecatedTests { private static final int PORT = 8182; diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java new file mode 100644 index 000000000..d3672e3f3 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java @@ -0,0 +1,54 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.Timeout; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.server.RouterFunctions; + +/** + * Tests for {@link McpSyncServer} using {@link WebFluxSseServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class WebFluxSseMcpSyncServerTests extends AbstractMcpSyncServerTests { + + private static final int PORT = 8182; + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private DisposableServer httpServer; + + private WebFluxSseServerTransportProvider transportProvider; + + @Override + protected McpServerTransportProvider createMcpTransportProvider() { + transportProvider = new WebFluxSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + return transportProvider; + } + + @Override + protected void onStart() { + HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + } + + @Override + protected void onClose() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-transport/mcp-webflux-sse-transport/src/test/java/org/springframework/ai/mcp/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/legacy/WebFluxSseIntegrationTests.java similarity index 80% rename from mcp-transport/mcp-webflux-sse-transport/src/test/java/org/springframework/ai/mcp/WebFluxSseIntegrationTests.java rename to mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/legacy/WebFluxSseIntegrationTests.java index ca968cfb7..981e114c9 100644 --- a/mcp-transport/mcp-webflux-sse-transport/src/test/java/org/springframework/ai/mcp/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/legacy/WebFluxSseIntegrationTests.java @@ -1,19 +1,7 @@ /* * Copyright 2024 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp; +package io.modelcontextprotocol.server.legacy; import java.time.Duration; import java.util.List; @@ -23,6 +11,23 @@ import java.util.function.Function; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -32,23 +37,6 @@ import reactor.netty.http.server.HttpServer; import reactor.test.StepVerifier; -import org.springframework.ai.mcp.client.McpClient; -import org.springframework.ai.mcp.client.transport.HttpClientSseClientTransport; -import org.springframework.ai.mcp.client.transport.WebFluxSseClientTransport; -import org.springframework.ai.mcp.server.McpServer; -import org.springframework.ai.mcp.server.McpServerFeatures; -import org.springframework.ai.mcp.server.transport.WebFluxSseServerTransport; -import org.springframework.ai.mcp.spec.McpError; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.CallToolResult; -import org.springframework.ai.mcp.spec.McpSchema.ClientCapabilities; -import org.springframework.ai.mcp.spec.McpSchema.CreateMessageRequest; -import org.springframework.ai.mcp.spec.McpSchema.CreateMessageResult; -import org.springframework.ai.mcp.spec.McpSchema.InitializeResult; -import org.springframework.ai.mcp.spec.McpSchema.Role; -import org.springframework.ai.mcp.spec.McpSchema.Root; -import org.springframework.ai.mcp.spec.McpSchema.ServerCapabilities; -import org.springframework.ai.mcp.spec.McpSchema.Tool; import org.springframework.http.server.reactive.HttpHandler; import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; import org.springframework.web.client.RestClient; @@ -100,12 +88,11 @@ public void after() { void testCreateMessageWithoutInitialization() { var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); + var messages = List.of(new McpSchema.SamplingMessage(Role.USER, new McpSchema.TextContent("Test message"))); var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); + var request = new CreateMessageRequest(messages, modelPrefs, null, + CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) @@ -126,12 +113,11 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) { InitializeResult initResult = client.initialize(); assertThat(initResult).isNotNull(); - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); + var messages = List.of(new McpSchema.SamplingMessage(Role.USER, new McpSchema.TextContent("Test message"))); var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); + var request = new CreateMessageRequest(messages, modelPrefs, null, + CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) @@ -163,12 +149,11 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException { InitializeResult initResult = client.initialize(); assertThat(initResult).isNotNull(); - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); + var messages = List.of(new McpSchema.SamplingMessage(Role.USER, new McpSchema.TextContent("Test message"))); var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); + var request = new CreateMessageRequest(messages, modelPrefs, null, + CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); StepVerifier.create(mcpAsyncServer.createMessage(request)).consumeNextWith(result -> { assertThat(result).isNotNull(); @@ -367,13 +352,13 @@ void testToolCallSuccess(String clientType) { var clientBuilder = clientBulders.get(clientType); - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { + new Tool("tool1", "tool1 description", emptyJsonSchema), request -> { // perform a blocking call to a remote service String response = RestClient.create() .get() - .uri("https://github.com/spring-projects-experimental/spring-ai-mcp/blob/main/README.md") + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); @@ -407,13 +392,13 @@ void testToolListChangeHandlingSuccess(String clientType) { var clientBuilder = clientBulders.get(clientType); - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { + new Tool("tool1", "tool1 description", emptyJsonSchema), request -> { // perform a blocking call to a remote service String response = RestClient.create() .get() - .uri("https://github.com/spring-projects-experimental/spring-ai-mcp/blob/main/README.md") + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); @@ -430,7 +415,7 @@ void testToolListChangeHandlingSuccess(String clientType) { // perform a blocking call to a remote service String response = RestClient.create() .get() - .uri("https://github.com/spring-projects-experimental/spring-ai-mcp/blob/main/README.md") + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); @@ -459,7 +444,7 @@ void testToolListChangeHandlingSuccess(String clientType) { // Add a new tool McpServerFeatures.SyncToolRegistration tool2 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), request -> callResponse); + new Tool("tool2", "tool2 description", emptyJsonSchema), request -> callResponse); mcpServer.addTool(tool2); diff --git a/mcp-transport/mcp-webflux-sse-transport/src/test/java/org/springframework/ai/mcp/server/transport/BlockingInputStream.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java similarity index 68% rename from mcp-transport/mcp-webflux-sse-transport/src/test/java/org/springframework/ai/mcp/server/transport/BlockingInputStream.java rename to mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java index 96690dbd8..0ab72a99f 100644 --- a/mcp-transport/mcp-webflux-sse-transport/src/test/java/org/springframework/ai/mcp/server/transport/BlockingInputStream.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java @@ -1,19 +1,7 @@ /* * Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. */ -package org.springframework.ai.mcp.server.transport; +package io.modelcontextprotocol.server.transport; import java.io.IOException; import java.io.InputStream; diff --git a/mcp-transport/mcp-webflux-sse-transport/src/test/resources/logback.xml b/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml similarity index 100% rename from mcp-transport/mcp-webflux-sse-transport/src/test/resources/logback.xml rename to mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml diff --git a/mcp-transport/mcp-webmvc-sse-transport/README.md b/mcp-spring/mcp-spring-webmvc/README.md similarity index 82% rename from mcp-transport/mcp-webmvc-sse-transport/README.md rename to mcp-spring/mcp-spring-webmvc/README.md index 0b73e73c2..9adf5b2ee 100644 --- a/mcp-transport/mcp-webmvc-sse-transport/README.md +++ b/mcp-spring/mcp-spring-webmvc/README.md @@ -2,8 +2,8 @@ ```xml - org.springframework.experimental - mcp-webmvc-sse-transport + io.modelcontextprotocol.sdk + mcp-spring-webmvc ``` diff --git a/mcp-transport/mcp-webmvc-sse-transport/pom.xml b/mcp-spring/mcp-spring-webmvc/pom.xml similarity index 83% rename from mcp-transport/mcp-webmvc-sse-transport/pom.xml rename to mcp-spring/mcp-spring-webmvc/pom.xml index 95c7df1ba..dc198ac14 100644 --- a/mcp-transport/mcp-webmvc-sse-transport/pom.xml +++ b/mcp-spring/mcp-spring-webmvc/pom.xml @@ -4,34 +4,34 @@ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd"> 4.0.0 - org.springframework.experimental + io.modelcontextprotocol.sdk mcp-parent - 0.6.0 + 0.8.0 ../../pom.xml - mcp-webmvc-sse-transport + mcp-spring-webmvc jar Spring Web MVC implementation of the Java MCP SSE transport - https://github.com/spring-projects-experimental/spring-ai-mcp + https://github.com/modelcontextprotocol/java-sdk - https://github.com/spring-projects-experimental/spring-ai-mcp - git://github.com/spring-projects-experimental/spring-ai-mcp.git - git@github.com:spring-projects-experimental/spring-ai-mcp.git + https://github.com/modelcontextprotocol/java-sdk + git://github.com/modelcontextprotocol/java-sdk.git + git@github.com/modelcontextprotocol/java-sdk.git - org.springframework.experimental + io.modelcontextprotocol.sdk mcp - 0.6.0 + 0.8.0 - org.springframework.experimental + io.modelcontextprotocol.sdk mcp-test - 0.6.0 + 0.8.0 test diff --git a/mcp-transport/mcp-webmvc-sse-transport/src/main/java/org/springframework/ai/mcp/server/transport/WebMvcSseServerTransport.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransport.java similarity index 93% rename from mcp-transport/mcp-webmvc-sse-transport/src/main/java/org/springframework/ai/mcp/server/transport/WebMvcSseServerTransport.java rename to mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransport.java index cee0b48ce..23193d106 100644 --- a/mcp-transport/mcp-webmvc-sse-transport/src/main/java/org/springframework/ai/mcp/server/transport/WebMvcSseServerTransport.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransport.java @@ -1,36 +1,25 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.server.transport; +package io.modelcontextprotocol.server.transport; import java.io.IOException; +import java.time.Duration; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; -import org.springframework.ai.mcp.spec.McpError; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.ServerMcpTransport; -import org.springframework.ai.mcp.util.Assert; import org.springframework.http.HttpStatus; import org.springframework.web.servlet.function.RouterFunction; import org.springframework.web.servlet.function.RouterFunctions; @@ -44,6 +33,9 @@ * a bridge between synchronous WebMVC operations and reactive programming patterns to * maintain compatibility with the reactive transport interface. * + * @deprecated This class will be removed in 0.9.0. Use + * {@link WebMvcSseServerTransportProvider}. + * *

* Key features: *

    @@ -68,12 +60,12 @@ * This implementation uses {@link ConcurrentHashMap} to safely manage multiple client * sessions in a thread-safe manner. Each client session is assigned a unique ID and * maintains its own SSE connection. - * * @author Christian Tzolov * @author Alexandros Pappas * @see ServerMcpTransport * @see RouterFunction */ +@Deprecated public class WebMvcSseServerTransport implements ServerMcpTransport { private static final Logger logger = LoggerFactory.getLogger(WebMvcSseServerTransport.class); @@ -229,6 +221,14 @@ private ServerResponse handleSseConnection(ServerRequest request) { // Send initial endpoint event try { return ServerResponse.sse(sseBuilder -> { + sseBuilder.onComplete(() -> { + logger.debug("SSE connection completed for session: {}", sessionId); + sessions.remove(sessionId); + }); + sseBuilder.onTimeout(() -> { + logger.debug("SSE connection timed out for session: {}", sessionId); + sessions.remove(sessionId); + }); ClientSession session = new ClientSession(sessionId, sseBuilder); this.sessions.put(sessionId, session); @@ -240,7 +240,7 @@ private ServerResponse handleSseConnection(ServerRequest request) { logger.error("Failed to poll event from session queue: {}", e.getMessage()); sseBuilder.error(e); } - }); + }, Duration.ZERO); } catch (Exception e) { logger.error("Failed to send initial endpoint event to session {}: {}", sessionId, e.getMessage()); @@ -271,6 +271,7 @@ private ServerResponse handleMessage(ServerRequest request) { // Convert the message to a Mono, apply the handler, and block for the // response + @SuppressWarnings("unused") McpSchema.JSONRPCMessage response = Mono.just(message).transform(connectHandler).block(); return ServerResponse.ok().build(); @@ -364,7 +365,7 @@ public Mono closeGracefully() { sessions.remove(sessionId); }); - logger.info("Graceful shutdown completed"); + logger.debug("Graceful shutdown completed"); }); } diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java new file mode 100644 index 000000000..65416b256 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -0,0 +1,399 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.time.Duration; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpStatus; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.RouterFunctions; +import org.springframework.web.servlet.function.ServerRequest; +import org.springframework.web.servlet.function.ServerResponse; +import org.springframework.web.servlet.function.ServerResponse.SseBuilder; + +/** + * Server-side implementation of the Model Context Protocol (MCP) transport layer using + * HTTP with Server-Sent Events (SSE) through Spring WebMVC. This implementation provides + * a bridge between synchronous WebMVC operations and reactive programming patterns to + * maintain compatibility with the reactive transport interface. + * + *

    + * Key features: + *

      + *
    • Implements bidirectional communication using HTTP POST for client-to-server + * messages and SSE for server-to-client messages
    • + *
    • Manages client sessions with unique IDs for reliable message delivery
    • + *
    • Supports graceful shutdown with proper session cleanup
    • + *
    • Provides JSON-RPC message handling through configured endpoints
    • + *
    • Includes built-in error handling and logging
    • + *
    + * + *

    + * The transport operates on two main endpoints: + *

      + *
    • {@code /sse} - The SSE endpoint where clients establish their event stream + * connection
    • + *
    • A configurable message endpoint where clients send their JSON-RPC messages via HTTP + * POST
    • + *
    + * + *

    + * This implementation uses {@link ConcurrentHashMap} to safely manage multiple client + * sessions in a thread-safe manner. Each client session is assigned a unique ID and + * maintains its own SSE connection. + * + * @author Christian Tzolov + * @author Alexandros Pappas + * @see McpServerTransportProvider + * @see RouterFunction + */ +public class WebMvcSseServerTransportProvider implements McpServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(WebMvcSseServerTransportProvider.class); + + /** + * Event type for JSON-RPC messages sent through the SSE connection. + */ + public static final String MESSAGE_EVENT_TYPE = "message"; + + /** + * Event type for sending the message endpoint URI to clients. + */ + public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + /** + * Default SSE endpoint path as specified by the MCP transport specification. + */ + public static final String DEFAULT_SSE_ENDPOINT = "/sse"; + + private final ObjectMapper objectMapper; + + private final String messageEndpoint; + + private final String sseEndpoint; + + private final RouterFunction routerFunction; + + private McpServerSession.Factory sessionFactory; + + /** + * Map of active client sessions, keyed by session ID. + */ + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + /** + * Flag indicating if the transport is shutting down. + */ + private volatile boolean isClosing = false; + + /** + * Constructs a new WebMvcSseServerTransportProvider instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of messages. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP POST. This endpoint will be communicated to clients through the + * SSE connection's initial endpoint event. + * @param sseEndpoint The endpoint URI where clients establish their SSE connections. + * @throws IllegalArgumentException if any parameter is null + */ + public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); + + this.objectMapper = objectMapper; + this.messageEndpoint = messageEndpoint; + this.sseEndpoint = sseEndpoint; + this.routerFunction = RouterFunctions.route() + .GET(this.sseEndpoint, this::handleSseConnection) + .POST(this.messageEndpoint, this::handleMessage) + .build(); + } + + /** + * Constructs a new WebMvcSseServerTransportProvider instance with the default SSE + * endpoint. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of messages. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP POST. This endpoint will be communicated to clients through the + * SSE connection's initial endpoint event. + * @throws IllegalArgumentException if either objectMapper or messageEndpoint is null + */ + public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { + this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); + } + + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + /** + * Broadcasts a notification to all connected clients through their SSE connections. + * The message is serialized to JSON and sent as an SSE event with type "message". If + * any errors occur during sending to a particular client, they are logged but don't + * prevent sending to other clients. + * @param method The method name for the notification + * @param params The parameters for the notification + * @return A Mono that completes when the broadcast attempt is finished + */ + @Override + public Mono notifyClients(String method, Map params) { + if (sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + + return Flux.fromIterable(sessions.values()) + .flatMap(session -> session.sendNotification(method, params) + .doOnError( + e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage())) + .onErrorComplete()) + .then(); + } + + /** + * Initiates a graceful shutdown of the transport. This method: + *

      + *
    • Sets the closing flag to prevent new connections
    • + *
    • Closes all active SSE connections
    • + *
    • Removes all session records
    • + *
    + * @return A Mono that completes when all cleanup operations are finished + */ + @Override + public Mono closeGracefully() { + return Flux.fromIterable(sessions.values()).doFirst(() -> { + this.isClosing = true; + logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); + }) + .flatMap(McpServerSession::closeGracefully) + .then() + .doOnSuccess(v -> logger.debug("Graceful shutdown completed")); + } + + /** + * Returns the RouterFunction that defines the HTTP endpoints for this transport. The + * router function handles two endpoints: + *
      + *
    • GET /sse - For establishing SSE connections
    • + *
    • POST [messageEndpoint] - For receiving JSON-RPC messages from clients
    • + *
    + * @return The configured RouterFunction for handling HTTP requests + */ + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + /** + * Handles new SSE connection requests from clients by creating a new session and + * establishing an SSE connection. This method: + *
      + *
    • Generates a unique session ID
    • + *
    • Creates a new session with a WebMvcMcpSessionTransport
    • + *
    • Sends an initial endpoint event to inform the client where to send + * messages
    • + *
    • Maintains the session in the sessions map
    • + *
    + * @param request The incoming server request + * @return A ServerResponse configured for SSE communication, or an error response if + * the server is shutting down or the connection fails + */ + private ServerResponse handleSseConnection(ServerRequest request) { + if (this.isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); + } + + String sessionId = UUID.randomUUID().toString(); + logger.debug("Creating new SSE connection for session: {}", sessionId); + + // Send initial endpoint event + try { + return ServerResponse.sse(sseBuilder -> { + sseBuilder.onComplete(() -> { + logger.debug("SSE connection completed for session: {}", sessionId); + sessions.remove(sessionId); + }); + sseBuilder.onTimeout(() -> { + logger.debug("SSE connection timed out for session: {}", sessionId); + sessions.remove(sessionId); + }); + + WebMvcMcpSessionTransport sessionTransport = new WebMvcMcpSessionTransport(sessionId, sseBuilder); + McpServerSession session = sessionFactory.create(sessionTransport); + this.sessions.put(sessionId, session); + + try { + sseBuilder.id(sessionId) + .event(ENDPOINT_EVENT_TYPE) + .data(messageEndpoint + "?sessionId=" + sessionId); + } + catch (Exception e) { + logger.error("Failed to send initial endpoint event: {}", e.getMessage()); + sseBuilder.error(e); + } + }, Duration.ZERO); + } + catch (Exception e) { + logger.error("Failed to send initial endpoint event to session {}: {}", sessionId, e.getMessage()); + sessions.remove(sessionId); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Handles incoming JSON-RPC messages from clients. This method: + *
      + *
    • Deserializes the request body into a JSON-RPC message
    • + *
    • Processes the message through the session's handle method
    • + *
    • Returns appropriate HTTP responses based on the processing result
    • + *
    + * @param request The incoming server request containing the JSON-RPC message + * @return A ServerResponse indicating success (200 OK) or appropriate error status + * with error details in case of failures + */ + private ServerResponse handleMessage(ServerRequest request) { + if (this.isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); + } + + if (!request.param("sessionId").isPresent()) { + return ServerResponse.badRequest().body(new McpError("Session ID missing in message endpoint")); + } + + String sessionId = request.param("sessionId").get(); + McpServerSession session = sessions.get(sessionId); + + if (session == null) { + return ServerResponse.status(HttpStatus.NOT_FOUND).body(new McpError("Session not found: " + sessionId)); + } + + try { + String body = request.body(String.class); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); + + // Process the message through the session's handle method + session.handle(message).block(); // Block for WebMVC compatibility + + return ServerResponse.ok().build(); + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + return ServerResponse.badRequest().body(new McpError("Invalid message format")); + } + catch (Exception e) { + logger.error("Error handling message: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage())); + } + } + + /** + * Implementation of McpServerTransport for WebMVC SSE sessions. This class handles + * the transport-level communication for a specific client session. + */ + private class WebMvcMcpSessionTransport implements McpServerTransport { + + private final String sessionId; + + private final SseBuilder sseBuilder; + + /** + * Creates a new session transport with the specified ID and SSE builder. + * @param sessionId The unique identifier for this session + * @param sseBuilder The SSE builder for sending server events to the client + */ + WebMvcMcpSessionTransport(String sessionId, SseBuilder sseBuilder) { + this.sessionId = sessionId; + this.sseBuilder = sseBuilder; + logger.debug("Session transport {} initialized with SSE builder", sessionId); + } + + /** + * Sends a JSON-RPC message to the client through the SSE connection. + * @param message The JSON-RPC message to send + * @return A Mono that completes when the message has been sent + */ + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return Mono.fromRunnable(() -> { + try { + String jsonText = objectMapper.writeValueAsString(message); + sseBuilder.id(sessionId).event(MESSAGE_EVENT_TYPE).data(jsonText); + logger.debug("Message sent to session {}", sessionId); + } + catch (Exception e) { + logger.error("Failed to send message to session {}: {}", sessionId, e.getMessage()); + sseBuilder.error(e); + } + }); + } + + /** + * Converts data from one type to another using the configured ObjectMapper. + * @param data The source data object to convert + * @param typeRef The target type reference + * @return The converted object of type T + * @param The target type + */ + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + /** + * Initiates a graceful shutdown of the transport. + * @return A Mono that completes when the shutdown is complete + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + logger.debug("Closing session transport: {}", sessionId); + try { + sseBuilder.complete(); + logger.debug("Successfully completed SSE builder for session {}", sessionId); + } + catch (Exception e) { + logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage()); + } + }); + } + + /** + * Closes the transport immediately. + */ + @Override + public void close() { + try { + sseBuilder.complete(); + logger.debug("Successfully completed SSE builder for session {}", sessionId); + } + catch (Exception e) { + logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage()); + } + } + + } + +} diff --git a/mcp-transport/mcp-webmvc-sse-transport/src/test/java/org/springframework/ai/mcp/server/WebMvcSseSyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportDeprecatedTests.java similarity index 79% rename from mcp-transport/mcp-webmvc-sse-transport/src/test/java/org/springframework/ai/mcp/server/WebMvcSseSyncServerTransportTests.java rename to mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportDeprecatedTests.java index 2c44c570f..c3f0e3220 100644 --- a/mcp-transport/mcp-webmvc-sse-transport/src/test/java/org/springframework/ai/mcp/server/WebMvcSseSyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportDeprecatedTests.java @@ -1,29 +1,17 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.server; +package io.modelcontextprotocol.server; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; import org.apache.catalina.Context; import org.apache.catalina.LifecycleException; import org.apache.catalina.startup.Tomcat; import org.junit.jupiter.api.Timeout; -import org.springframework.ai.mcp.server.transport.WebMvcSseServerTransport; -import org.springframework.ai.mcp.spec.ServerMcpTransport; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; @@ -32,8 +20,9 @@ import org.springframework.web.servlet.function.RouterFunction; import org.springframework.web.servlet.function.ServerResponse; +@Deprecated @Timeout(15) -class WebMvcSseSyncServerTransportTests extends AbstractMcpSyncServerTests { +class WebMvcSseAsyncServerTransportDeprecatedTests extends AbstractMcpAsyncServerDeprecatedTests { private static final String MESSAGE_ENDPOINT = "/mcp/message"; diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java new file mode 100644 index 000000000..08d5de671 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java @@ -0,0 +1,117 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.Timeout; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.servlet.DispatcherServlet; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +@Timeout(15) +class WebMvcSseAsyncServerTransportTests extends AbstractMcpAsyncServerTests { + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private static final int PORT = 8181; + + private Tomcat tomcat; + + private McpServerTransportProvider transportProvider; + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { + return new WebMvcSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + } + + @Bean + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + + private AnnotationConfigWebApplicationContext appContext; + + @Override + protected McpServerTransportProvider createMcpTransportProvider() { + // Set up Tomcat first + tomcat = new Tomcat(); + tomcat.setPort(PORT); + + // Set Tomcat base directory to java.io.tmpdir to avoid permission issues + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + // Use the same directory for document base + Context context = tomcat.addContext("", baseDir); + + // Create and configure Spring WebMvc context + appContext = new AnnotationConfigWebApplicationContext(); + appContext.register(TestConfig.class); + appContext.setServletContext(context.getServletContext()); + appContext.refresh(); + + // Get the transport from Spring context + transportProvider = appContext.getBean(WebMvcSseServerTransportProvider.class); + + // Create DispatcherServlet with our Spring context + DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); + // dispatcherServlet.setThrowExceptionIfNoHandlerFound(true); + + // Add servlet to Tomcat and get the wrapper + var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); + wrapper.setLoadOnStartup(1); + context.addServletMappingDecoded("/*", "dispatcherServlet"); + + try { + tomcat.start(); + tomcat.getConnector(); // Create and start the connector + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + return transportProvider; + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (transportProvider != null) { + transportProvider.closeGracefully().block(); + } + if (appContext != null) { + appContext.close(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + +} diff --git a/mcp-transport/mcp-webmvc-sse-transport/src/test/java/org/springframework/ai/mcp/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationDeprecatedTests.java similarity index 88% rename from mcp-transport/mcp-webmvc-sse-transport/src/test/java/org/springframework/ai/mcp/server/WebMvcSseIntegrationTests.java rename to mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationDeprecatedTests.java index 308be78f4..f2b593d8d 100644 --- a/mcp-transport/mcp-webmvc-sse-transport/src/test/java/org/springframework/ai/mcp/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationDeprecatedTests.java @@ -1,19 +1,7 @@ /* * Copyright 2024 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.server; +package io.modelcontextprotocol.server; import java.time.Duration; import java.util.List; @@ -22,6 +10,20 @@ import java.util.function.Function; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; import org.apache.catalina.Context; import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; @@ -29,24 +31,8 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import reactor.test.StepVerifier; -import org.springframework.ai.mcp.client.McpClient; -import org.springframework.ai.mcp.client.transport.HttpClientSseClientTransport; -import org.springframework.ai.mcp.server.transport.WebMvcSseServerTransport; -import org.springframework.ai.mcp.spec.McpError; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.CallToolResult; -import org.springframework.ai.mcp.spec.McpSchema.ClientCapabilities; -import org.springframework.ai.mcp.spec.McpSchema.CreateMessageRequest; -import org.springframework.ai.mcp.spec.McpSchema.CreateMessageResult; -import org.springframework.ai.mcp.spec.McpSchema.InitializeResult; -import org.springframework.ai.mcp.spec.McpSchema.Role; -import org.springframework.ai.mcp.spec.McpSchema.Root; -import org.springframework.ai.mcp.spec.McpSchema.ServerCapabilities; -import org.springframework.ai.mcp.spec.McpSchema.Tool; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.web.client.RestClient; @@ -60,9 +46,8 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.awaitility.Awaitility.await; -public class WebMvcSseIntegrationTests { - - private static final Logger logger = LoggerFactory.getLogger(WebMvcSseIntegrationTests.class); +@Deprecated +public class WebMvcSseIntegrationDeprecatedTests { private static final int PORT = 8183; @@ -411,7 +396,7 @@ void testToolCallSuccess() { // perform a blocking call to a remote service String response = RestClient.create() .get() - .uri("https://github.com/spring-projects-experimental/spring-ai-mcp/blob/main/README.md") + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); @@ -448,7 +433,7 @@ void testToolListChangeHandlingSuccess() { // perform a blocking call to a remote service String response = RestClient.create() .get() - .uri("https://github.com/spring-projects-experimental/spring-ai-mcp/blob/main/README.md") + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); @@ -465,7 +450,7 @@ void testToolListChangeHandlingSuccess() { // perform a blocking call to a remote service String response = RestClient.create() .get() - .uri("https://github.com/spring-projects-experimental/spring-ai-mcp/blob/main/README.md") + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java new file mode 100644 index 000000000..3ff755ca9 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -0,0 +1,535 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.client.RestClient; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.servlet.DispatcherServlet; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.mock; + +public class WebMvcSseIntegrationTests { + + private static final int PORT = 8183; + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private WebMvcSseServerTransportProvider mcpServerTransportProvider; + + McpClient.SyncSpec clientBuilder; + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { + return new WebMvcSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + } + + @Bean + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + + private Tomcat tomcat; + + private AnnotationConfigWebApplicationContext appContext; + + @BeforeEach + public void before() { + + // Set up Tomcat first + tomcat = new Tomcat(); + tomcat.setPort(PORT); + + // Set Tomcat base directory to java.io.tmpdir to avoid permission issues + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + // Use the same directory for document base + Context context = tomcat.addContext("", baseDir); + + // Create and configure Spring WebMvc context + appContext = new AnnotationConfigWebApplicationContext(); + appContext.register(TestConfig.class); + appContext.setServletContext(context.getServletContext()); + appContext.refresh(); + + // Get the transport from Spring context + mcpServerTransportProvider = appContext.getBean(WebMvcSseServerTransportProvider.class); + + // Create DispatcherServlet with our Spring context + DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); + // dispatcherServlet.setThrowExceptionIfNoHandlerFound(true); + + // Add servlet to Tomcat and get the wrapper + var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); + wrapper.setLoadOnStartup(1); + wrapper.setAsyncSupported(true); + context.addServletMappingDecoded("/*", "dispatcherServlet"); + + try { + // Configure and start the connector with async support + var connector = tomcat.getConnector(); + connector.setAsyncTimeout(3000); // 3 seconds timeout for async requests + tomcat.start(); + assertThat(tomcat.getServer().getState() == LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + this.clientBuilder = McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT)); + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (appContext != null) { + appContext.close(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + // --------------------------------------- + // Sampling Tests + // --------------------------------------- + @Test + void testCreateMessageWithoutSamplingCapabilities() { + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); + + return Mono.just(mock(CallToolResult.class)); + }); + + McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + } + } + + @Test + void testCreateMessageSuccess() throws InterruptedException { + + // Client + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + @Test + void testRootsSuccess() { + List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Remove a root + mcpClient.removeRoot(roots.get(0).uri()); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); + }); + + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testRootsWithoutCapability() { + + McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.listRoots(); // try to list roots + + return mock(CallToolResult.class); + }); + + var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { + }).tools(tool).build(); + + // Create client without roots capability + // No roots capability + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build(); + + assertThat(mcpClient.initialize()).isNotNull(); + + // Attempt to list roots should fail + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testRootsNotifciationWithEmptyRootsList() { + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(List.of()) // Empty roots list + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testRootsWithMultipleHandlers() { + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef1 = new AtomicReference<>(); + AtomicReference> rootsRef2 = new AtomicReference<>(); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + assertThat(mcpClient.initialize()).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testRootsServerCloseWithActiveSubscription() { + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Close server while subscription is active + mcpServer.close(); + + // Verify client can handle server closure gracefully + mcpClient.close(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @Test + void testToolCallSuccess() { + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + var mcpClient = clientBuilder.build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testToolListChangeHandlingSuccess() { + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + rootsRef.set(toolsUpdate); + }).build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + mcpServer.notifyToolsListChanged(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); + }); + + // Remove a tool + mcpServer.removeTool("tool1"); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + + // Add a new tool + McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), (exchange, request) -> callResponse); + + mcpServer.addTool(tool2); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testInitialize() { + + var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); + + var mcpClient = clientBuilder.build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.close(); + mcpServer.close(); + } + +} diff --git a/mcp-transport/mcp-webmvc-sse-transport/src/test/java/org/springframework/ai/mcp/server/WebMvcSseAsyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportDeprecatedTests.java similarity index 79% rename from mcp-transport/mcp-webmvc-sse-transport/src/test/java/org/springframework/ai/mcp/server/WebMvcSseAsyncServerTransportTests.java rename to mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportDeprecatedTests.java index 374a5c59d..8656665ed 100644 --- a/mcp-transport/mcp-webmvc-sse-transport/src/test/java/org/springframework/ai/mcp/server/WebMvcSseAsyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportDeprecatedTests.java @@ -1,29 +1,17 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.server; +package io.modelcontextprotocol.server; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; import org.apache.catalina.Context; import org.apache.catalina.LifecycleException; import org.apache.catalina.startup.Tomcat; import org.junit.jupiter.api.Timeout; -import org.springframework.ai.mcp.server.transport.WebMvcSseServerTransport; -import org.springframework.ai.mcp.spec.ServerMcpTransport; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; @@ -32,8 +20,9 @@ import org.springframework.web.servlet.function.RouterFunction; import org.springframework.web.servlet.function.ServerResponse; +@Deprecated @Timeout(15) -class WebMvcSseAsyncServerTransportTests extends AbstractMcpAsyncServerTests { +class WebMvcSseSyncServerTransportDeprecatedTests extends AbstractMcpSyncServerDeprecatedTests { private static final String MESSAGE_ENDPOINT = "/mcp/message"; diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java new file mode 100644 index 000000000..b85bed379 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java @@ -0,0 +1,116 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.Timeout; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.servlet.DispatcherServlet; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +@Timeout(15) +class WebMvcSseSyncServerTransportTests extends AbstractMcpSyncServerTests { + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private static final int PORT = 8181; + + private Tomcat tomcat; + + private WebMvcSseServerTransportProvider transportProvider; + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { + return new WebMvcSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + } + + @Bean + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + + private AnnotationConfigWebApplicationContext appContext; + + @Override + protected WebMvcSseServerTransportProvider createMcpTransportProvider() { + // Set up Tomcat first + tomcat = new Tomcat(); + tomcat.setPort(PORT); + + // Set Tomcat base directory to java.io.tmpdir to avoid permission issues + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + // Use the same directory for document base + Context context = tomcat.addContext("", baseDir); + + // Create and configure Spring WebMvc context + appContext = new AnnotationConfigWebApplicationContext(); + appContext.register(TestConfig.class); + appContext.setServletContext(context.getServletContext()); + appContext.refresh(); + + // Get the transport from Spring context + transportProvider = appContext.getBean(WebMvcSseServerTransportProvider.class); + + // Create DispatcherServlet with our Spring context + DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); + // dispatcherServlet.setThrowExceptionIfNoHandlerFound(true); + + // Add servlet to Tomcat and get the wrapper + var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); + wrapper.setLoadOnStartup(1); + context.addServletMappingDecoded("/*", "dispatcherServlet"); + + try { + tomcat.start(); + tomcat.getConnector(); // Create and start the connector + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + return transportProvider; + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (transportProvider != null) { + transportProvider.closeGracefully().block(); + } + if (appContext != null) { + appContext.close(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + +} diff --git a/mcp-transport/mcp-webmvc-sse-transport/src/test/resources/logback.xml b/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml similarity index 66% rename from mcp-transport/mcp-webmvc-sse-transport/src/test/resources/logback.xml rename to mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml index 517af52f3..bc1140bb5 100644 --- a/mcp-transport/mcp-webmvc-sse-transport/src/test/resources/logback.xml +++ b/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml @@ -9,16 +9,16 @@ - + - + - + - + diff --git a/mcp-test/pom.xml b/mcp-test/pom.xml index 212003a36..033043985 100644 --- a/mcp-test/pom.xml +++ b/mcp-test/pom.xml @@ -4,27 +4,27 @@ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd"> 4.0.0 - org.springframework.experimental + io.modelcontextprotocol.sdk mcp-parent - 0.6.0 + 0.8.0 mcp-test jar Tests for the Java MCP SDK Provides some shared test fasilities for the MCP Java SDK - https://github.com/spring-projects-experimental/spring-ai-mcp + https://github.com/modelcontextprotocol/java-sdk - https://github.com/spring-projects-experimental/spring-ai-mcp - git://github.com/spring-projects-experimental/spring-ai-mcp.git - git@github.com:spring-projects-experimental/spring-ai-mcp.git + https://github.com/modelcontextprotocol/java-sdk + git://github.com/modelcontextprotocol/java-sdk.git + git@github.com/modelcontextprotocol/java-sdk.git - org.springframework.experimental + io.modelcontextprotocol.sdk mcp - 0.6.0 + 0.8.0 diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java b/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java new file mode 100644 index 000000000..cef3fb9fa --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java @@ -0,0 +1,97 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.BiConsumer; +import java.util.function.Function; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; + +/** + * A mock implementation of the {@link McpClientTransport} and {@link ServerMcpTransport} + * interfaces. + */ +public class MockMcpTransport implements McpClientTransport, ServerMcpTransport { + + private final Sinks.Many inbound = Sinks.many().unicast().onBackpressureBuffer(); + + private final List sent = new ArrayList<>(); + + private final BiConsumer interceptor; + + public MockMcpTransport() { + this((t, msg) -> { + }); + } + + public MockMcpTransport(BiConsumer interceptor) { + this.interceptor = interceptor; + } + + public void simulateIncomingMessage(McpSchema.JSONRPCMessage message) { + if (inbound.tryEmitNext(message).isFailure()) { + throw new RuntimeException("Failed to process incoming message " + message); + } + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + sent.add(message); + interceptor.accept(this, message); + return Mono.empty(); + } + + public McpSchema.JSONRPCRequest getLastSentMessageAsRequest() { + return (JSONRPCRequest) getLastSentMessage(); + } + + public McpSchema.JSONRPCNotification getLastSentMessageAsNotification() { + return (JSONRPCNotification) getLastSentMessage(); + } + + public McpSchema.JSONRPCMessage getLastSentMessage() { + return !sent.isEmpty() ? sent.get(sent.size() - 1) : null; + } + + private volatile boolean connected = false; + + @Override + public Mono connect(Function, Mono> handler) { + if (connected) { + return Mono.error(new IllegalStateException("Already connected")); + } + connected = true; + return inbound.asFlux() + .flatMap(message -> Mono.just(message).transform(handler)) + .doFinally(signal -> connected = false) + .then(); + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + connected = false; + inbound.tryEmitComplete(); + // Wait for all subscribers to complete + return Mono.empty(); + }); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return new ObjectMapper().convertValue(data, typeRef); + } + +} diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java new file mode 100644 index 000000000..713563519 --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -0,0 +1,491 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.time.Duration; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.SubscribeRequest; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpSchema.UnsubscribeRequest; +import io.modelcontextprotocol.spec.McpTransport; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Test suite for the {@link McpAsyncClient} that can be used with different + * {@link McpTransport} implementations. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +public abstract class AbstractMcpAsyncClientTests { + + private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!"; + + abstract protected McpClientTransport createMcpTransport(); + + protected void onStart() { + } + + protected void onClose() { + } + + protected Duration getRequestTimeout() { + return Duration.ofSeconds(14); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(2); + } + + McpAsyncClient client(McpClientTransport transport) { + return client(transport, Function.identity()); + } + + McpAsyncClient client(McpClientTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); + + assertThatCode(() -> { + McpClient.AsyncSpec builder = McpClient.async(transport) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) + .capabilities(ClientCapabilities.builder().roots(true).build()); + builder = customizer.apply(builder); + client.set(builder.build()); + }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(McpClientTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(McpClientTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); + } + } + + @BeforeEach + void setUp() { + onStart(); + } + + @AfterEach + void tearDown() { + onClose(); + } + + void verifyInitializationTimeout(Function> operation, String action) { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.withVirtualTime(() -> operation.apply(mcpAsyncClient)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before " + action)) + .verify(); + }); + } + + @Test + void testConstructorWithInvalidArguments() { + assertThatThrownBy(() -> McpClient.async(null).build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport must not be null"); + + assertThatThrownBy(() -> McpClient.async(createMcpTransport()).requestTimeout(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Request timeout must not be null"); + } + + @Test + void testListToolsWithoutInitialization() { + verifyInitializationTimeout(client -> client.listTools(null), "listing tools"); + } + + @Test + void testListTools() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(null))) + .consumeNextWith(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); + + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }) + .verifyComplete(); + }); + } + + @Test + void testPingWithoutInitialization() { + verifyInitializationTimeout(client -> client.ping(), "pinging the server"); + } + + @Test + void testPing() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())) + .expectNextCount(1) + .verifyComplete(); + }); + } + + @Test + void testCallToolWithoutInitialization() { + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); + verifyInitializationTimeout(client -> client.callTool(callToolRequest), "calling tools"); + } + + @Test + void testCallTool() { + withClient(createMcpTransport(), mcpAsyncClient -> { + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); + + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest))) + .consumeNextWith(callToolResult -> { + assertThat(callToolResult).isNotNull().satisfies(result -> { + assertThat(result.content()).isNotNull(); + assertThat(result.isError()).isNull(); + }); + }) + .verifyComplete(); + }); + } + + @Test + void testCallToolWithInvalidTool() { + withClient(createMcpTransport(), mcpAsyncClient -> { + CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", + Map.of("message", ECHO_TEST_MESSAGE)); + + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(invalidRequest))) + .consumeErrorWith( + e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Unknown tool: nonexistent_tool")) + .verify(); + }); + } + + @Test + void testListResourcesWithoutInitialization() { + verifyInitializationTimeout(client -> client.listResources(null), "listing resources"); + } + + @Test + void testListResources() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(null))) + .consumeNextWith(resources -> { + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); + }) + .verifyComplete(); + }); + } + + @Test + void testMcpAsyncClientState() { + withClient(createMcpTransport(), mcpAsyncClient -> { + assertThat(mcpAsyncClient).isNotNull(); + }); + } + + @Test + void testListPromptsWithoutInitialization() { + verifyInitializationTimeout(client -> client.listPrompts(null), "listing " + "prompts"); + } + + @Test + void testListPrompts() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(null))) + .consumeNextWith(prompts -> { + assertThat(prompts).isNotNull().satisfies(result -> { + assertThat(result.prompts()).isNotNull(); + + if (!result.prompts().isEmpty()) { + Prompt firstPrompt = result.prompts().get(0); + assertThat(firstPrompt.name()).isNotNull(); + assertThat(firstPrompt.description()).isNotNull(); + } + }); + }) + .verifyComplete(); + }); + } + + @Test + void testGetPromptWithoutInitialization() { + GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); + verifyInitializationTimeout(client -> client.getPrompt(request), "getting " + "prompts"); + } + + @Test + void testGetPrompt() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier + .create(mcpAsyncClient.initialize() + .then(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of())))) + .consumeNextWith(prompt -> { + assertThat(prompt).isNotNull().satisfies(result -> { + assertThat(result.messages()).isNotEmpty(); + assertThat(result.messages()).hasSize(1); + }); + }) + .verifyComplete(); + }); + } + + @Test + void testRootsListChangedWithoutInitialization() { + verifyInitializationTimeout(client -> client.rootsListChangedNotification(), + "sending roots list changed notification"); + } + + @Test + void testRootsListChanged() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification())) + .verifyComplete(); + }); + } + + @Test + void testInitializeWithRootsListProviders() { + withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), + client -> { + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + }); + } + + @Test + void testAddRoot() { + withClient(createMcpTransport(), mcpAsyncClient -> { + Root newRoot = new Root("file:///new/test/path", "new-test-root"); + StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete(); + }); + } + + @Test + void testAddRootWithNullValue() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.addRoot(null)) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Root must not be null")) + .verify(); + }); + } + + @Test + void testRemoveRoot() { + withClient(createMcpTransport(), mcpAsyncClient -> { + Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + StepVerifier.create(mcpAsyncClient.addRoot(root)).verifyComplete(); + + StepVerifier.create(mcpAsyncClient.removeRoot(root.uri())).verifyComplete(); + }); + } + + @Test + void testRemoveNonExistentRoot() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Root with uri 'nonexistent-uri' not found")) + .verify(); + }); + } + + @Test + @Disabled + void testReadResource() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); + StepVerifier.create(mcpAsyncClient.readResource(firstResource)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.contents()).isNotNull(); + }).verifyComplete(); + } + }).verifyComplete(); + }); + } + + @Test + void testListResourceTemplatesWithoutInitialization() { + verifyInitializationTimeout(client -> client.listResourceTemplates(), "listing resource templates"); + } + + @Test + void testListResourceTemplates() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) + .consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }) + .verifyComplete(); + }); + } + + // @Test + void testResourceSubscription() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); + + // Test subscribe + StepVerifier.create(mcpAsyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) + .verifyComplete(); + + // Test unsubscribe + StepVerifier.create(mcpAsyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) + .verifyComplete(); + } + }).verifyComplete(); + }); + } + + @Test + void testNotificationHandlers() { + AtomicBoolean toolsNotificationReceived = new AtomicBoolean(false); + AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); + AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); + + withClient(createMcpTransport(), + builder -> builder + .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) + .resourcesChangeConsumer( + resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) + .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))), + mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()) + .expectNextMatches(Objects::nonNull) + .verifyComplete(); + }); + } + + @Test + void testInitializeWithSamplingCapability() { + ClientCapabilities capabilities = ClientCapabilities.builder().sampling().build(); + CreateMessageResult createMessageResult = CreateMessageResult.builder() + .message("test") + .model("test-model") + .build(); + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).sampling(request -> Mono.just(createMessageResult)), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + }); + } + + @Test + void testInitializeWithAllCapabilities() { + var capabilities = ClientCapabilities.builder() + .experimental(Map.of("feature", "test")) + .roots(true) + .sampling() + .build(); + + Function> samplingHandler = request -> Mono + .just(CreateMessageResult.builder().message("test").model("test-model").build()); + + withClient(createMcpTransport(), builder -> builder.capabilities(capabilities).sampling(samplingHandler), + client -> + + StepVerifier.create(client.initialize()).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.capabilities()).isNotNull(); + }).verifyComplete()); + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + + @Test + void testLoggingLevelsWithoutInitialization() { + verifyInitializationTimeout(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), + "setting logging level"); + } + + @Test + void testLoggingLevels() { + withClient(createMcpTransport(), mcpAsyncClient -> { + Mono testAllLevels = mcpAsyncClient.initialize().then(Mono.defer(() -> { + Mono chain = Mono.empty(); + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + chain = chain.then(mcpAsyncClient.setLoggingLevel(level)); + } + return chain; + })); + + StepVerifier.create(testAllLevels).verifyComplete(); + }); + } + + @Test + void testLoggingConsumer() { + AtomicBoolean logReceived = new AtomicBoolean(false); + + withClient(createMcpTransport(), + builder -> builder.loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + StepVerifier.create(client.closeGracefully()).verifyComplete(); + + }); + + } + + @Test + void testLoggingWithNullNotification() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.setLoggingLevel(null)) + .expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null")) + .verify(); + }); + } + +} diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java new file mode 100644 index 000000000..128441f80 --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -0,0 +1,449 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.time.Duration; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.ListResourceTemplatesResult; +import io.modelcontextprotocol.spec.McpSchema.ListResourcesResult; +import io.modelcontextprotocol.spec.McpSchema.ListToolsResult; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.SubscribeRequest; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpSchema.UnsubscribeRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for MCP Client Session functionality. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +public abstract class AbstractMcpSyncClientTests { + + private static final String TEST_MESSAGE = "Hello MCP Spring AI!"; + + abstract protected McpClientTransport createMcpTransport(); + + protected void onStart() { + } + + protected void onClose() { + } + + protected Duration getRequestTimeout() { + return Duration.ofSeconds(14); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(2); + } + + McpSyncClient client(McpClientTransport transport) { + return client(transport, Function.identity()); + } + + McpSyncClient client(McpClientTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); + + assertThatCode(() -> { + McpClient.SyncSpec builder = McpClient.sync(transport) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) + .capabilities(ClientCapabilities.builder().roots(true).build()); + builder = customizer.apply(builder); + client.set(builder.build()); + }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(McpClientTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(McpClientTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + assertThat(client.closeGracefully()).isTrue(); + } + } + + @BeforeEach + void setUp() { + onStart(); + + } + + @AfterEach + void tearDown() { + onClose(); + } + + static final Object DUMMY_RETURN_VALUE = new Object(); + + void verifyNotificationTimesOut(Consumer operation, String action) { + verifyCallTimesOut(client -> { + operation.accept(client); + return DUMMY_RETURN_VALUE; + }, action); + } + + void verifyCallTimesOut(Function blockingOperation, String action) { + withClient(createMcpTransport(), mcpSyncClient -> { + // This scheduler is not replaced by virtual time scheduler + Scheduler customScheduler = Schedulers.newBoundedElastic(1, 1, "actualBoundedElastic"); + + StepVerifier.withVirtualTime(() -> Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient)) + // Offload the blocking call to the real scheduler + .subscribeOn(customScheduler)) + .expectSubscription() + // This works without actually waiting but executes all the + // tasks pending execution on the VirtualTimeScheduler. + // It is possible to execute the blocking code from the operation + // because it is blocked on a dedicated Scheduler and the main + // flow is not blocked and uses the VirtualTimeScheduler. + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before " + action)) + .verify(); + + customScheduler.dispose(); + }); + } + + @Test + void testConstructorWithInvalidArguments() { + assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport must not be null"); + + assertThatThrownBy(() -> McpClient.sync(createMcpTransport()).requestTimeout(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Request timeout must not be null"); + } + + @Test + void testListToolsWithoutInitialization() { + verifyCallTimesOut(client -> client.listTools(null), "listing tools"); + } + + @Test + void testListTools() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListToolsResult tools = mcpSyncClient.listTools(null); + + assertThat(tools).isNotNull().satisfies(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); + + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }); + }); + } + + @Test + void testCallToolsWithoutInitialization() { + verifyCallTimesOut(client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), + "calling tools"); + } + + @Test + void testCallTools() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))); + + assertThat(toolResult).isNotNull().satisfies(result -> { + + assertThat(result.content()).hasSize(1); + + TextContent content = (TextContent) result.content().get(0); + + assertThat(content).isNotNull(); + assertThat(content.text()).isNotNull(); + assertThat(content.text()).contains("7"); + }); + }); + } + + @Test + void testPingWithoutInitialization() { + verifyCallTimesOut(client -> client.ping(), "pinging the server"); + } + + @Test + void testPing() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException(); + }); + } + + @Test + void testCallToolWithoutInitialization() { + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); + verifyCallTimesOut(client -> client.callTool(callToolRequest), "calling tools"); + } + + @Test + void testCallTool() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); + + CallToolResult callToolResult = mcpSyncClient.callTool(callToolRequest); + + assertThat(callToolResult).isNotNull().satisfies(result -> { + assertThat(result.content()).isNotNull(); + assertThat(result.isError()).isNull(); + }); + }); + } + + @Test + void testCallToolWithInvalidTool() { + withClient(createMcpTransport(), mcpSyncClient -> { + CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE)); + + assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class); + }); + } + + @Test + void testRootsListChangedWithoutInitialization() { + verifyNotificationTimesOut(client -> client.rootsListChangedNotification(), + "sending roots list changed notification"); + } + + @Test + void testRootsListChanged() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException(); + }); + } + + @Test + void testListResourcesWithoutInitialization() { + verifyCallTimesOut(client -> client.listResources(null), "listing resources"); + } + + @Test + void testListResources() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourcesResult resources = mcpSyncClient.listResources(null); + + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); + }); + } + + @Test + void testClientSessionState() { + withClient(createMcpTransport(), mcpSyncClient -> { + assertThat(mcpSyncClient).isNotNull(); + }); + } + + @Test + void testInitializeWithRootsListProviders() { + withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), + mcpSyncClient -> { + + assertThatCode(() -> { + mcpSyncClient.initialize(); + mcpSyncClient.close(); + }).doesNotThrowAnyException(); + }); + } + + @Test + void testAddRoot() { + withClient(createMcpTransport(), mcpSyncClient -> { + Root newRoot = new Root("file:///new/test/path", "new-test-root"); + assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException(); + }); + } + + @Test + void testAddRootWithNullValue() { + withClient(createMcpTransport(), mcpSyncClient -> { + assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null"); + }); + } + + @Test + void testRemoveRoot() { + withClient(createMcpTransport(), mcpSyncClient -> { + Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + assertThatCode(() -> { + mcpSyncClient.addRoot(root); + mcpSyncClient.removeRoot(root.uri()); + }).doesNotThrowAnyException(); + }); + } + + @Test + void testRemoveNonExistentRoot() { + withClient(createMcpTransport(), mcpSyncClient -> { + assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri")) + .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); + }); + } + + @Test + void testReadResourceWithoutInitialization() { + Resource resource = new Resource("test://uri", "Test Resource", null, null, null); + verifyCallTimesOut(client -> client.readResource(resource), "reading resources"); + } + + @Test + void testReadResource() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourcesResult resources = mcpSyncClient.listResources(null); + + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); + ReadResourceResult result = mcpSyncClient.readResource(firstResource); + + assertThat(result).isNotNull(); + assertThat(result.contents()).isNotNull(); + } + }); + } + + @Test + void testListResourceTemplatesWithoutInitialization() { + verifyCallTimesOut(client -> client.listResourceTemplates(null), "listing resource templates"); + } + + @Test + void testListResourceTemplates() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(null); + + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }); + } + + // @Test + void testResourceSubscription() { + withClient(createMcpTransport(), mcpSyncClient -> { + ListResourcesResult resources = mcpSyncClient.listResources(null); + + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); + + // Test subscribe + assertThatCode(() -> mcpSyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) + .doesNotThrowAnyException(); + + // Test unsubscribe + assertThatCode(() -> mcpSyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) + .doesNotThrowAnyException(); + } + }); + } + + @Test + void testNotificationHandlers() { + AtomicBoolean toolsNotificationReceived = new AtomicBoolean(false); + AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); + AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); + + withClient(createMcpTransport(), + builder -> builder.toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) + .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) + .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)), + client -> { + + assertThatCode(() -> { + client.initialize(); + client.close(); + }).doesNotThrowAnyException(); + }); + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + + @Test + void testLoggingLevelsWithoutInitialization() { + verifyNotificationTimesOut(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), + "setting logging level"); + } + + @Test + void testLoggingLevels() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + // Test all logging levels + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + assertThatCode(() -> mcpSyncClient.setLoggingLevel(level)).doesNotThrowAnyException(); + } + }); + } + + @Test + void testLoggingConsumer() { + AtomicBoolean logReceived = new AtomicBoolean(false); + withClient(createMcpTransport(), builder -> builder.requestTimeout(getRequestTimeout()) + .loggingConsumer(notification -> logReceived.set(true)), client -> { + assertThatCode(() -> { + client.initialize(); + client.close(); + }).doesNotThrowAnyException(); + }); + } + + @Test + void testLoggingWithNullNotification() { + withClient(createMcpTransport(), mcpSyncClient -> assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null)) + .hasMessageContaining("Logging level must not be null")); + } + +} diff --git a/mcp-test/src/main/java/org/springframework/ai/mcp/server/AbstractMcpAsyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java similarity index 91% rename from mcp-test/src/main/java/org/springframework/ai/mcp/server/AbstractMcpAsyncServerTests.java rename to mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java index 0c7c1aa58..005d78f25 100644 --- a/mcp-test/src/main/java/org/springframework/ai/mcp/server/AbstractMcpAsyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java @@ -1,43 +1,30 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.server; +package io.modelcontextprotocol.server; import java.time.Duration; import java.util.List; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; -import org.springframework.ai.mcp.spec.McpError; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.CallToolResult; -import org.springframework.ai.mcp.spec.McpSchema.GetPromptResult; -import org.springframework.ai.mcp.spec.McpSchema.Prompt; -import org.springframework.ai.mcp.spec.McpSchema.PromptMessage; -import org.springframework.ai.mcp.spec.McpSchema.ReadResourceResult; -import org.springframework.ai.mcp.spec.McpSchema.Resource; -import org.springframework.ai.mcp.spec.McpSchema.ServerCapabilities; -import org.springframework.ai.mcp.spec.McpSchema.Tool; -import org.springframework.ai.mcp.spec.McpTransport; -import org.springframework.ai.mcp.spec.ServerMcpTransport; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -48,7 +35,8 @@ * * @author Christian Tzolov */ -public abstract class AbstractMcpAsyncServerTests { +@Deprecated +public abstract class AbstractMcpAsyncServerDeprecatedTests { private static final String TEST_TOOL_NAME = "test-tool"; @@ -79,7 +67,8 @@ void tearDown() { @Test void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.async(null)).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> McpServer.async((ServerMcpTransport) null)) + .isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); assertThatThrownBy(() -> McpServer.async(createMcpTransport()).serverInfo((McpSchema.Implementation) null)) diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java new file mode 100644 index 000000000..7bcb9a8b2 --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -0,0 +1,468 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.List; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Test suite for the {@link McpAsyncServer} that can be used with different + * {@link McpTransportProvider} implementations. + * + * @author Christian Tzolov + */ +// KEEP IN SYNC with the class in mcp-test module +public abstract class AbstractMcpAsyncServerTests { + + private static final String TEST_TOOL_NAME = "test-tool"; + + private static final String TEST_RESOURCE_URI = "test://resource"; + + private static final String TEST_PROMPT_NAME = "test-prompt"; + + abstract protected McpServerTransportProvider createMcpTransportProvider(); + + protected void onStart() { + } + + protected void onClose() { + } + + @BeforeEach + void setUp() { + } + + @AfterEach + void tearDown() { + onClose(); + } + + // --------------------------------------- + // Server Lifecycle Tests + // --------------------------------------- + + @Test + void testConstructorWithInvalidArguments() { + assertThatThrownBy(() -> McpServer.async((McpServerTransportProvider) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport provider must not be null"); + + assertThatThrownBy( + () -> McpServer.async(createMcpTransportProvider()).serverInfo((McpSchema.Implementation) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Server info must not be null"); + } + + @Test + void testGracefulShutdown() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(mcpAsyncServer.closeGracefully()).verifyComplete(); + } + + @Test + void testImmediateClose() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @Test + void testAddTool() { + Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, + (excnage, args) -> Mono.just(new CallToolResult(List.of(), false))))) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddDuplicateTool() { + Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(duplicateTool, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .build(); + + StepVerifier + .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(duplicateTool, + (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); + }); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveTool() { + Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .build(); + + StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentTool() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.removeTool("nonexistent-tool")).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class).hasMessage("Tool with name 'nonexistent-tool' not found"); + }); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testNotifyToolsListChanged() { + Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .build(); + + StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resources Tests + // --------------------------------------- + + @Test + void testNotifyResourcesListChanged() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(mcpAsyncServer.notifyResourcesListChanged()).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResource() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", + null); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(mcpAsyncServer.addResource(specification)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithNullSpecification() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceSpecification) null)) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class).hasMessage("Resource must not be null"); + }); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); + + Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", + null); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(serverWithoutResources.addResource(specification)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with resource capabilities"); + }); + } + + @Test + void testRemoveResourceWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); + + StepVerifier.create(serverWithoutResources.removeResource(TEST_RESOURCE_URI)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with resource capabilities"); + }); + } + + // --------------------------------------- + // Prompts Tests + // --------------------------------------- + + @Test + void testNotifyPromptsListChanged() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(mcpAsyncServer.notifyPromptsListChanged()).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddPromptWithNullSpecification() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(false).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptSpecification) null)) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt specification must not be null"); + }); + } + + @Test + void testAddPromptWithoutCapability() { + // Create a server without prompt capabilities + McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); + + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); + McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( + prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); + + StepVerifier.create(serverWithoutPrompts.addPrompt(specification)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with prompt capabilities"); + }); + } + + @Test + void testRemovePromptWithoutCapability() { + // Create a server without prompt capabilities + McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); + + StepVerifier.create(serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with prompt capabilities"); + }); + } + + @Test + void testRemovePrompt() { + String TEST_PROMPT_NAME_TO_REMOVE = "TEST_PROMPT_NAME678"; + + Prompt prompt = new Prompt(TEST_PROMPT_NAME_TO_REMOVE, "Test Prompt", List.of()); + McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( + prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); + + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .prompts(specification) + .build(); + + StepVerifier.create(mcpAsyncServer.removePrompt(TEST_PROMPT_NAME_TO_REMOVE)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentPrompt() { + var mcpAsyncServer2 = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer2.removePrompt("nonexistent-prompt")).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Prompt with name 'nonexistent-prompt' not found"); + }); + + assertThatCode(() -> mcpAsyncServer2.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + + @Test + void testRootsChangeHandlers() { + // Test with single consumer + var rootsReceived = new McpSchema.Root[1]; + var consumerCalled = new boolean[1]; + + var singleConsumerServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { + consumerCalled[0] = true; + if (!roots.isEmpty()) { + rootsReceived[0] = roots.get(0); + } + }))) + .build(); + + assertThat(singleConsumerServer).isNotNull(); + assertThatCode(() -> singleConsumerServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + onClose(); + + // Test with multiple consumers + var consumer1Called = new boolean[1]; + var consumer2Called = new boolean[1]; + var rootsContent = new List[1]; + + var multipleConsumersServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { + consumer1Called[0] = true; + rootsContent[0] = roots; + }), (exchange, roots) -> Mono.fromRunnable(() -> consumer2Called[0] = true))) + .build(); + + assertThat(multipleConsumersServer).isNotNull(); + assertThatCode(() -> multipleConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + onClose(); + + // Test error handling + var errorHandlingServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> { + throw new RuntimeException("Test error"); + })) + .build(); + + assertThat(errorHandlingServer).isNotNull(); + assertThatCode(() -> errorHandlingServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + onClose(); + + // Test without consumers + var noConsumersServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); + + assertThat(noConsumersServer).isNotNull(); + assertThatCode(() -> noConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + + @Test + void testLoggingLevels() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().build()) + .build(); + + // Test all logging levels + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + var notification = McpSchema.LoggingMessageNotification.builder() + .level(level) + .logger("test-logger") + .data("Test message with level " + level) + .build(); + + StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); + } + } + + @Test + void testLoggingWithoutCapability() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().build()) // No logging capability + .build(); + + var notification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Test log message") + .build(); + + StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); + } + + @Test + void testLoggingWithNullNotification() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().build()) + .build(); + + StepVerifier.create(mcpAsyncServer.loggingNotification(null)).verifyError(McpError.class); + } + +} diff --git a/mcp-test/src/main/java/org/springframework/ai/mcp/server/AbstractMcpSyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java similarity index 90% rename from mcp-test/src/main/java/org/springframework/ai/mcp/server/AbstractMcpSyncServerTests.java rename to mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java index ddb827a05..c6625acaa 100644 --- a/mcp-test/src/main/java/org/springframework/ai/mcp/server/AbstractMcpSyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java @@ -1,40 +1,27 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.server; +package io.modelcontextprotocol.server; import java.util.List; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.springframework.ai.mcp.spec.McpError; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.CallToolResult; -import org.springframework.ai.mcp.spec.McpSchema.GetPromptResult; -import org.springframework.ai.mcp.spec.McpSchema.Prompt; -import org.springframework.ai.mcp.spec.McpSchema.PromptMessage; -import org.springframework.ai.mcp.spec.McpSchema.ReadResourceResult; -import org.springframework.ai.mcp.spec.McpSchema.Resource; -import org.springframework.ai.mcp.spec.McpSchema.ServerCapabilities; -import org.springframework.ai.mcp.spec.McpSchema.Tool; -import org.springframework.ai.mcp.spec.McpTransport; -import org.springframework.ai.mcp.spec.ServerMcpTransport; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -45,7 +32,7 @@ * * @author Christian Tzolov */ -public abstract class AbstractMcpSyncServerTests { +public abstract class AbstractMcpSyncServerDeprecatedTests { private static final String TEST_TOOL_NAME = "test-tool"; @@ -77,7 +64,7 @@ void tearDown() { @Test void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.sync(null)).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> McpServer.sync((ServerMcpTransport) null)).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); assertThatThrownBy(() -> McpServer.sync(createMcpTransport()).serverInfo(null)) diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java new file mode 100644 index 000000000..7846e053b --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -0,0 +1,440 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.util.List; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Test suite for the {@link McpSyncServer} that can be used with different + * {@link McpTransportProvider} implementations. + * + * @author Christian Tzolov + */ +// KEEP IN SYNC with the class in mcp-test module +public abstract class AbstractMcpSyncServerTests { + + private static final String TEST_TOOL_NAME = "test-tool"; + + private static final String TEST_RESOURCE_URI = "test://resource"; + + private static final String TEST_PROMPT_NAME = "test-prompt"; + + abstract protected McpServerTransportProvider createMcpTransportProvider(); + + protected void onStart() { + } + + protected void onClose() { + } + + @BeforeEach + void setUp() { + // onStart(); + } + + @AfterEach + void tearDown() { + onClose(); + } + + // --------------------------------------- + // Server Lifecycle Tests + // --------------------------------------- + + @Test + void testConstructorWithInvalidArguments() { + assertThatThrownBy(() -> McpServer.sync((McpServerTransportProvider) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport provider must not be null"); + + assertThatThrownBy(() -> McpServer.sync(createMcpTransportProvider()).serverInfo(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Server info must not be null"); + } + + @Test + void testGracefulShutdown() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testImmediateClose() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.close()).doesNotThrowAnyException(); + } + + @Test + void testGetAsyncServer() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + + assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @Test + void testAddTool() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); + assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(newTool, + (exchange, args) -> new CallToolResult(List.of(), false)))) + .doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddDuplicateTool() { + Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(duplicateTool, (exchange, args) -> new CallToolResult(List.of(), false)) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool, + (exchange, args) -> new CallToolResult(List.of(), false)))) + .isInstanceOf(McpError.class) + .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testRemoveTool() { + Tool tool = new McpSchema.Tool(TEST_TOOL_NAME, "Test tool", emptyJsonSchema); + + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(tool, (exchange, args) -> new CallToolResult(List.of(), false)) + .build(); + + assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentTool() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.removeTool("nonexistent-tool")).isInstanceOf(McpError.class) + .hasMessage("Tool with name 'nonexistent-tool' not found"); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testNotifyToolsListChanged() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.notifyToolsListChanged()).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resources Tests + // --------------------------------------- + + @Test + void testNotifyResourcesListChanged() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.notifyResourcesListChanged()).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddResource() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", + null); + McpServerFeatures.SyncResourceSpecification specificaiton = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); + + assertThatCode(() -> mcpSyncServer.addResource(specificaiton)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithNullSpecifiation() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceSpecification) null)) + .isInstanceOf(McpError.class) + .hasMessage("Resource must not be null"); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithoutCapability() { + var serverWithoutResources = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); + + Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", + null); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); + + assertThatThrownBy(() -> serverWithoutResources.addResource(specification)).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with resource capabilities"); + } + + @Test + void testRemoveResourceWithoutCapability() { + var serverWithoutResources = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); + + assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with resource capabilities"); + } + + // --------------------------------------- + // Prompts Tests + // --------------------------------------- + + @Test + void testNotifyPromptsListChanged() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.notifyPromptsListChanged()).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddPromptWithNullSpecification() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(false).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptSpecification) null)) + .isInstanceOf(McpError.class) + .hasMessage("Prompt specification must not be null"); + } + + @Test + void testAddPromptWithoutCapability() { + var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); + + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); + McpServerFeatures.SyncPromptSpecification specificaiton = new McpServerFeatures.SyncPromptSpecification(prompt, + (exchange, req) -> new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); + + assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specificaiton)).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with prompt capabilities"); + } + + @Test + void testRemovePromptWithoutCapability() { + var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); + + assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with prompt capabilities"); + } + + @Test + void testRemovePrompt() { + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); + McpServerFeatures.SyncPromptSpecification specificaiton = new McpServerFeatures.SyncPromptSpecification(prompt, + (exchange, req) -> new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); + + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .prompts(specificaiton) + .build(); + + assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentPrompt() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.removePrompt("nonexistent-prompt")).isInstanceOf(McpError.class) + .hasMessage("Prompt with name 'nonexistent-prompt' not found"); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + + @Test + void testRootsChangeHandlers() { + // Test with single consumer + var rootsReceived = new McpSchema.Root[1]; + var consumerCalled = new boolean[1]; + + var singleConsumerServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchage, roots) -> { + consumerCalled[0] = true; + if (!roots.isEmpty()) { + rootsReceived[0] = roots.get(0); + } + })) + .build(); + + assertThat(singleConsumerServer).isNotNull(); + assertThatCode(() -> singleConsumerServer.closeGracefully()).doesNotThrowAnyException(); + onClose(); + + // Test with multiple consumers + var consumer1Called = new boolean[1]; + var consumer2Called = new boolean[1]; + var rootsContent = new List[1]; + + var multipleConsumersServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> { + consumer1Called[0] = true; + rootsContent[0] = roots; + }, (exchange, roots) -> consumer2Called[0] = true)) + .build(); + + assertThat(multipleConsumersServer).isNotNull(); + assertThatCode(() -> multipleConsumersServer.closeGracefully()).doesNotThrowAnyException(); + onClose(); + + // Test error handling + var errorHandlingServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> { + throw new RuntimeException("Test error"); + })) + .build(); + + assertThat(errorHandlingServer).isNotNull(); + assertThatCode(() -> errorHandlingServer.closeGracefully()).doesNotThrowAnyException(); + onClose(); + + // Test without consumers + var noConsumersServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + + assertThat(noConsumersServer).isNotNull(); + assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + + @Test + void testLoggingLevels() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().build()) + .build(); + + // Test all logging levels + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + var notification = McpSchema.LoggingMessageNotification.builder() + .level(level) + .logger("test-logger") + .data("Test message with level " + level) + .build(); + + assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); + } + } + + @Test + void testLoggingWithoutCapability() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().build()) // No logging capability + .build(); + + var notification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Test log message") + .build(); + + assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); + } + + @Test + void testLoggingWithNullNotification() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.loggingNotification(null)).isInstanceOf(McpError.class); + } + +} diff --git a/mcp-test/src/main/java/org/springframework/ai/mcp/MockMcpTransport.java b/mcp-test/src/main/java/org/springframework/ai/mcp/MockMcpTransport.java deleted file mode 100644 index 20736b1a6..000000000 --- a/mcp-test/src/main/java/org/springframework/ai/mcp/MockMcpTransport.java +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mcp; - -import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Function; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.Sinks; -import reactor.core.scheduler.Schedulers; - -import org.springframework.ai.mcp.spec.ClientMcpTransport; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.JSONRPCNotification; -import org.springframework.ai.mcp.spec.McpSchema.JSONRPCRequest; -import org.springframework.ai.mcp.spec.ServerMcpTransport; - -@SuppressWarnings("unused") -public class MockMcpTransport implements ClientMcpTransport, ServerMcpTransport { - - private final AtomicInteger inboundMessageCount = new AtomicInteger(0); - - private final Sinks.Many outgoing = Sinks.many().multicast().onBackpressureBuffer(); - - private final Sinks.Many inbound = Sinks.many().unicast().onBackpressureBuffer(); - - private final Flux outboundView = outgoing.asFlux().cache(1); - - public void simulateIncomingMessage(McpSchema.JSONRPCMessage message) { - if (inbound.tryEmitNext(message).isFailure()) { - throw new RuntimeException("Failed to emit message " + message); - } - inboundMessageCount.incrementAndGet(); - } - - @Override - public Mono sendMessage(McpSchema.JSONRPCMessage message) { - if (outgoing.tryEmitNext(message).isFailure()) { - return Mono.error(new RuntimeException("Can't emit outgoing message " + message)); - } - return Mono.empty(); - } - - public McpSchema.JSONRPCRequest getLastSentMessageAsRequest() { - return (JSONRPCRequest) outboundView.blockFirst(); - } - - public McpSchema.JSONRPCNotification getLastSentMessageAsNotifiation() { - return (JSONRPCNotification) outboundView.blockFirst(); - } - - public McpSchema.JSONRPCMessage getLastSentMessage() { - return outboundView.blockFirst(); - } - - private volatile boolean connected = false; - - @Override - public Mono connect(Function, Mono> handler) { - if (connected) { - return Mono.error(new IllegalStateException("Already connected")); - } - connected = true; - return inbound.asFlux() - .publishOn(Schedulers.boundedElastic()) - .flatMap(message -> Mono.just(message).transform(handler)) - .doFinally(signal -> connected = false) - .then(); - } - - @Override - public Mono closeGracefully() { - return Mono.defer(() -> { - connected = false; - outgoing.tryEmitComplete(); - inbound.tryEmitComplete(); - return Mono.empty(); - }); - } - - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return new ObjectMapper().convertValue(data, typeRef); - } - -} \ No newline at end of file diff --git a/mcp-test/src/main/java/org/springframework/ai/mcp/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/org/springframework/ai/mcp/client/AbstractMcpAsyncClientTests.java deleted file mode 100644 index 2579f573c..000000000 --- a/mcp-test/src/main/java/org/springframework/ai/mcp/client/AbstractMcpAsyncClientTests.java +++ /dev/null @@ -1,375 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mcp.client; - -import java.time.Duration; -import java.util.Map; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Function; - -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -import org.springframework.ai.mcp.spec.ClientMcpTransport; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.CallToolRequest; -import org.springframework.ai.mcp.spec.McpSchema.ClientCapabilities; -import org.springframework.ai.mcp.spec.McpSchema.CreateMessageRequest; -import org.springframework.ai.mcp.spec.McpSchema.CreateMessageResult; -import org.springframework.ai.mcp.spec.McpSchema.GetPromptRequest; -import org.springframework.ai.mcp.spec.McpSchema.Prompt; -import org.springframework.ai.mcp.spec.McpSchema.Resource; -import org.springframework.ai.mcp.spec.McpSchema.Root; -import org.springframework.ai.mcp.spec.McpSchema.SubscribeRequest; -import org.springframework.ai.mcp.spec.McpSchema.Tool; -import org.springframework.ai.mcp.spec.McpSchema.UnsubscribeRequest; -import org.springframework.ai.mcp.spec.McpTransport; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Test suite for the {@link McpAsyncClient} that can be used with different - * {@link McpTransport} implementations. - * - * @author Christian Tzolov - * @author Dariusz Jędrzejczyk - */ -public abstract class AbstractMcpAsyncClientTests { - - private McpAsyncClient mcpAsyncClient; - - protected ClientMcpTransport mcpTransport; - - private static final Duration TIMEOUT = Duration.ofSeconds(20); - - private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!"; - - abstract protected ClientMcpTransport createMcpTransport(); - - protected void onStart() { - } - - protected void onClose() { - } - - @BeforeEach - void setUp() { - onStart(); - this.mcpTransport = createMcpTransport(); - - assertThatCode(() -> { - mcpAsyncClient = McpClient.async(mcpTransport) - .requestTimeout(TIMEOUT) - .capabilities(ClientCapabilities.builder().roots(true).build()) - .build(); - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); - } - - @AfterEach - void tearDown() { - if (mcpAsyncClient != null) { - assertThatCode(() -> mcpAsyncClient.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - } - onClose(); - } - - @Test - void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport must not be null"); - - assertThatThrownBy(() -> McpClient.sync(mcpTransport).requestTimeout(null).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Request timeout must not be null"); - } - - @Test - void testListTools() { - StepVerifier.create(mcpAsyncClient.listTools(null)).consumeNextWith(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); - - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); - }).verifyComplete(); - } - - @Test - void testPing() { - assertThatCode(() -> mcpAsyncClient.ping().block()).doesNotThrowAnyException(); - } - - @Test - void testCallTool() { - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - - StepVerifier.create(mcpAsyncClient.callTool(callToolRequest)).consumeNextWith(callToolResult -> { - assertThat(callToolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).isNotNull(); - assertThat(result.isError()).isNull(); - }); - }).verifyComplete(); - } - - @Test - void testCallToolWithInvalidTool() { - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", ECHO_TEST_MESSAGE)); - - assertThatThrownBy(() -> mcpAsyncClient.callTool(invalidRequest).block()).isInstanceOf(Exception.class); - } - - @Test - void testListResources() { - StepVerifier.create(mcpAsyncClient.listResources(null)).consumeNextWith(resources -> { - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } - }); - }).verifyComplete(); - } - - @Test - void testMcpAsyncClientState() { - assertThat(mcpAsyncClient).isNotNull(); - } - - @Test - void testListPrompts() { - StepVerifier.create(mcpAsyncClient.listPrompts(null)).consumeNextWith(prompts -> { - assertThat(prompts).isNotNull().satisfies(result -> { - assertThat(result.prompts()).isNotNull(); - - if (!result.prompts().isEmpty()) { - Prompt firstPrompt = result.prompts().get(0); - assertThat(firstPrompt.name()).isNotNull(); - assertThat(firstPrompt.description()).isNotNull(); - } - }); - }).verifyComplete(); - } - - @Test - void testGetPrompt() { - StepVerifier.create(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of()))) - .consumeNextWith(prompt -> { - assertThat(prompt).isNotNull().satisfies(result -> { - assertThat(result.messages()).isNotEmpty(); - assertThat(result.messages()).hasSize(1); - }); - }) - .verifyComplete(); - } - - @Test - void testRootsListChanged() { - assertThatCode(() -> mcpAsyncClient.rootsListChangedNotification().block()).doesNotThrowAnyException(); - } - - @Test - void testInitializeWithRootsListProviders() { - var transport = createMcpTransport(); - - var client = McpClient.async(transport) - .requestTimeout(TIMEOUT) - .roots(new Root("file:///test/path", "test-root")) - .build(); - - assertThatCode(() -> client.initialize().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - - assertThatCode(() -> client.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddRoot() { - Root newRoot = new Root("file:///new/test/path", "new-test-root"); - assertThatCode(() -> mcpAsyncClient.addRoot(newRoot).block()).doesNotThrowAnyException(); - } - - @Test - void testAddRootWithNullValue() { - assertThatThrownBy(() -> mcpAsyncClient.addRoot(null).block()).hasMessageContaining("Root must not be null"); - } - - @Test - void testRemoveRoot() { - Root root = new Root("file:///test/path/to/remove", "root-to-remove"); - assertThatCode(() -> { - mcpAsyncClient.addRoot(root).block(); - mcpAsyncClient.removeRoot(root.uri()).block(); - }).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonExistentRoot() { - assertThatThrownBy(() -> mcpAsyncClient.removeRoot("nonexistent-uri").block()) - .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); - } - - @Test - @Disabled - void testReadResource() { - StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - StepVerifier.create(mcpAsyncClient.readResource(firstResource)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.contents()).isNotNull(); - }).verifyComplete(); - } - }).verifyComplete(); - } - - @Test - void testListResourceTemplates() { - StepVerifier.create(mcpAsyncClient.listResourceTemplates()).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); - }).verifyComplete(); - } - - // @Test - void testResourceSubscription() { - StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - - // Test subscribe - StepVerifier.create(mcpAsyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) - .verifyComplete(); - - // Test unsubscribe - StepVerifier.create(mcpAsyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) - .verifyComplete(); - } - }).verifyComplete(); - } - - @Test - void testNotificationHandlers() { - AtomicBoolean toolsNotificationReceived = new AtomicBoolean(false); - AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); - AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); - - var transport = createMcpTransport(); - var client = McpClient.async(transport) - .requestTimeout(TIMEOUT) - .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) - .resourcesChangeConsumer(resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) - .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))) - .build(); - - assertThatCode(() -> { - client.initialize().block(); - // Trigger notifications - client.sendResourcesListChanged().block(); - client.promptListChangedNotification().block(); - client.closeGracefully().block(); - }).doesNotThrowAnyException(); - } - - @Test - void testInitializeWithSamplingCapability() { - var transport = createMcpTransport(); - - var capabilities = ClientCapabilities.builder().sampling().build(); - - var client = McpClient.async(transport) - .requestTimeout(TIMEOUT) - .capabilities(capabilities) - .sampling(request -> Mono.just(CreateMessageResult.builder().message("test").model("test-model").build())) - .build(); - - assertThatCode(() -> { - client.initialize().block(Duration.ofSeconds(10)); - client.closeGracefully().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); - } - - @Test - void testInitializeWithAllCapabilities() { - var transport = createMcpTransport(); - - var capabilities = ClientCapabilities.builder() - .experimental(Map.of("feature", "test")) - .roots(true) - .sampling() - .build(); - - Function> samplingHandler = request -> Mono - .just(CreateMessageResult.builder().message("test").model("test-model").build()); - var client = McpClient.async(transport) - .requestTimeout(TIMEOUT) - .capabilities(capabilities) - .sampling(samplingHandler) - .build(); - - assertThatCode(() -> { - var result = client.initialize().block(Duration.ofSeconds(10)); - assertThat(result).isNotNull(); - assertThat(result.capabilities()).isNotNull(); - client.closeGracefully().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @Test - void testLoggingLevels() { - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - StepVerifier.create(mcpAsyncClient.setLoggingLevel(level)).verifyComplete(); - } - } - - @Test - void testLoggingConsumer() { - AtomicBoolean logReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - - var client = McpClient.async(transport) - .requestTimeout(TIMEOUT) - .loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))) - .build(); - - assertThatCode(() -> { - client.initialize().block(Duration.ofSeconds(10)); - client.closeGracefully().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); - } - - @Test - void testLoggingWithNullNotification() { - assertThatThrownBy(() -> mcpAsyncClient.setLoggingLevel(null).block()) - .hasMessageContaining("Logging level must not be null"); - } - -} diff --git a/mcp-test/src/main/java/org/springframework/ai/mcp/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/org/springframework/ai/mcp/client/AbstractMcpSyncClientTests.java deleted file mode 100644 index c0126125d..000000000 --- a/mcp-test/src/main/java/org/springframework/ai/mcp/client/AbstractMcpSyncClientTests.java +++ /dev/null @@ -1,315 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mcp.client; - -import java.time.Duration; -import java.util.Map; -import java.util.concurrent.atomic.AtomicBoolean; - -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import org.springframework.ai.mcp.spec.ClientMcpTransport; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.CallToolRequest; -import org.springframework.ai.mcp.spec.McpSchema.CallToolResult; -import org.springframework.ai.mcp.spec.McpSchema.ClientCapabilities; -import org.springframework.ai.mcp.spec.McpSchema.ListResourceTemplatesResult; -import org.springframework.ai.mcp.spec.McpSchema.ListResourcesResult; -import org.springframework.ai.mcp.spec.McpSchema.ListToolsResult; -import org.springframework.ai.mcp.spec.McpSchema.ReadResourceResult; -import org.springframework.ai.mcp.spec.McpSchema.Resource; -import org.springframework.ai.mcp.spec.McpSchema.Root; -import org.springframework.ai.mcp.spec.McpSchema.SubscribeRequest; -import org.springframework.ai.mcp.spec.McpSchema.TextContent; -import org.springframework.ai.mcp.spec.McpSchema.Tool; -import org.springframework.ai.mcp.spec.McpSchema.UnsubscribeRequest; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Unit tests for MCP Client Session functionality. - * - * @author Christian Tzolov - * @author Dariusz Jędrzejczyk - */ -public abstract class AbstractMcpSyncClientTests { - - private McpSyncClient mcpSyncClient; - - private static final Duration TIMEOUT = Duration.ofSeconds(10); - - private static final String TEST_MESSAGE = "Hello MCP Spring AI!"; - - protected ClientMcpTransport mcpTransport; - - abstract protected ClientMcpTransport createMcpTransport(); - - abstract protected void onStart(); - - abstract protected void onClose(); - - @BeforeEach - void setUp() { - onStart(); - this.mcpTransport = createMcpTransport(); - - assertThatCode(() -> { - mcpSyncClient = McpClient.sync(mcpTransport) - .requestTimeout(TIMEOUT) - .capabilities(ClientCapabilities.builder().roots(true).build()) - .build(); - mcpSyncClient.initialize(); - }).doesNotThrowAnyException(); - } - - @AfterEach - void tearDown() { - if (mcpSyncClient != null) { - assertThatCode(() -> mcpSyncClient.close()).doesNotThrowAnyException(); - } - onClose(); - } - - @Test - void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport must not be null"); - - assertThatThrownBy(() -> McpClient.sync(mcpTransport).requestTimeout(null).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Request timeout must not be null"); - } - - @Test - void testListTools() { - ListToolsResult tools = mcpSyncClient.listTools(null); - - assertThat(tools).isNotNull().satisfies(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); - - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); - }); - } - - @Test - void testCallTools() { - CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))); - - assertThat(toolResult).isNotNull().satisfies(result -> { - - assertThat(result.content()).hasSize(1); - - TextContent content = (TextContent) result.content().get(0); - - assertThat(content).isNotNull(); - assertThat(content.text()).isNotNull(); - assertThat(content.text()).contains("7"); - }); - } - - @Test - void testPing() { - assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException(); - } - - @Test - void testCallTool() { - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - - CallToolResult callToolResult = mcpSyncClient.callTool(callToolRequest); - - assertThat(callToolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).isNotNull(); - assertThat(result.isError()).isNull(); - }); - } - - @Test - void testCallToolWithInvalidTool() { - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE)); - - assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class); - } - - @Test - void testRootsListChanged() { - assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException(); - } - - @Test - void testListResources() { - ListResourcesResult resources = mcpSyncClient.listResources(null); - - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } - }); - } - - @Test - void testClientSessionState() { - assertThat(mcpSyncClient).isNotNull(); - } - - @Test - void testInitializeWithRootsListProviders() { - var transport = createMcpTransport(); - - var client = McpClient.sync(transport) - .requestTimeout(TIMEOUT) - .roots(new Root("file:///test/path", "test-root")) - .build(); - - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); - } - - @Test - void testAddRoot() { - Root newRoot = new Root("file:///new/test/path", "new-test-root"); - assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException(); - } - - @Test - void testAddRootWithNullValue() { - assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null"); - } - - @Test - void testRemoveRoot() { - Root root = new Root("file:///test/path/to/remove", "root-to-remove"); - assertThatCode(() -> { - mcpSyncClient.addRoot(root); - mcpSyncClient.removeRoot(root.uri()); - }).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonExistentRoot() { - assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri")) - .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); - } - - @Test - void testReadResource() { - ListResourcesResult resources = mcpSyncClient.listResources(null); - - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - ReadResourceResult result = mcpSyncClient.readResource(firstResource); - - assertThat(result).isNotNull(); - assertThat(result.contents()).isNotNull(); - } - } - - @Test - void testListResourceTemplates() { - ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(null); - - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); - } - - // @Test - void testResourceSubscription() { - ListResourcesResult resources = mcpSyncClient.listResources(null); - - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - - // Test subscribe - assertThatCode(() -> mcpSyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) - .doesNotThrowAnyException(); - - // Test unsubscribe - assertThatCode(() -> mcpSyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) - .doesNotThrowAnyException(); - } - } - - @Test - void testNotificationHandlers() { - AtomicBoolean toolsNotificationReceived = new AtomicBoolean(false); - AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); - AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); - - var transport = createMcpTransport(); - var client = McpClient.sync(transport) - .requestTimeout(TIMEOUT) - .toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) - .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) - .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)) - .build(); - - assertThatCode(() -> { - client.initialize(); - // Trigger notifications - client.sendResourcesListChanged(); - client.promptListChangedNotification(); - client.close(); - }).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @Test - void testLoggingLevels() { - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - assertThatCode(() -> mcpSyncClient.setLoggingLevel(level)).doesNotThrowAnyException(); - } - } - - @Test - void testLoggingConsumer() { - AtomicBoolean logReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - - var client = McpClient.sync(transport) - .requestTimeout(TIMEOUT) - .loggingConsumer(notification -> logReceived.set(true)) - .build(); - - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); - } - - @Test - void testLoggingWithNullNotification() { - assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null)) - .hasMessageContaining("Logging level must not be null"); - } - -} diff --git a/mcp-transport/mcp-webflux-sse-transport/src/test/java/org/springframework/ai/mcp/server/transport/_SseServerTransportTests.java_ b/mcp-transport/mcp-webflux-sse-transport/src/test/java/org/springframework/ai/mcp/server/transport/_SseServerTransportTests.java_ deleted file mode 100644 index 80f31a227..000000000 --- a/mcp-transport/mcp-webflux-sse-transport/src/test/java/org/springframework/ai/mcp/server/transport/_SseServerTransportTests.java_ +++ /dev/null @@ -1,283 +0,0 @@ -package org.springframework.ai.mcp.server.transport; - -import java.time.Duration; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; - -import com.fasterxml.jackson.databind.ObjectMapper; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.Timeout; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.Disposable; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.JSONRPCRequest; -import org.springframework.http.HttpStatus; -import org.springframework.http.MediaType; -import org.springframework.http.codec.ServerSentEvent; -import org.springframework.test.web.reactive.server.WebTestClient; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Tests for the {@link SseServerTransport} class. - * - * @author Christian Tzolov - */ -@Timeout(15) -class SseServerTransportTests { - - private static final Logger logger = LoggerFactory.getLogger(SseServerTransportTests.class); - - private ObjectMapper objectMapper; - - private String messageEndpoint; - - private SseServerTransport transport; - - private WebTestClient webTestClient; - - @BeforeEach - void setUp() { - objectMapper = new ObjectMapper(); - messageEndpoint = "/message"; - transport = new SseServerTransport(objectMapper, messageEndpoint); - webTestClient = WebTestClient.bindToRouterFunction(transport.getRouterFunction()).build(); - } - - @Test - void constructorValidation() { - assertThatThrownBy(() -> new SseServerTransport(null, "/message")).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("ObjectMapper must not be null"); - - assertThatThrownBy(() -> new SseServerTransport(new ObjectMapper(), null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Message endpoint must not be null"); - } - - @Test - void testSseConnectionEstablishment() { - List> events = new ArrayList<>(); - - webTestClient.get() - .uri("/sse") - .accept(MediaType.TEXT_EVENT_STREAM) - .exchange() - .expectStatus() - .isOk() - .expectHeader() - .contentTypeCompatibleWith(MediaType.TEXT_EVENT_STREAM) - .returnResult(String.class) - .getResponseBody() - .map(data -> ServerSentEvent.builder().data(data).build()) - .take(1) // Take only the initial endpoint event - .subscribe(events::add); - - // Wait a bit for the event to be received - StepVerifier.create(Mono.delay(Duration.ofMillis(500))).expectNextCount(1).verifyComplete(); - - assertThat(events).hasSize(1); - assertThat(events.get(0).data()).isEqualTo(messageEndpoint); - } - - @Test - void testMessageHandling() { - AtomicReference receivedMessage = new AtomicReference<>(); - - // Set up message handler - transport.connect(message -> { - message.doOnNext(receivedMessage::set).subscribe(); - return Mono.empty(); - }).block(); - - // Create a test message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", - Map.of("key", "value")); - - // Send message to endpoint - webTestClient.post() - .uri(messageEndpoint) - .contentType(MediaType.APPLICATION_JSON) - .bodyValue(testMessage) - .exchange() - .expectStatus() - .isOk(); - - // Verify the message was received and processed - assertThat(receivedMessage.get()).isNotNull(); - McpSchema.JSONRPCRequest receivedRequest = (McpSchema.JSONRPCRequest) receivedMessage.get(); - assertThat(receivedRequest.id()).isEqualTo(testMessage.id()); - assertThat(receivedRequest.method()).isEqualTo(testMessage.method()); - } - - @Test - @Disabled("Flaky test") - void testBroadcastMessage() { - // Create test clients - int clientCount = 3; - CountDownLatch connectLatch = new CountDownLatch(clientCount); - CountDownLatch messageLatch = new CountDownLatch(clientCount); - - List>> allReceivedEvents = new ArrayList<>(); - List subscriptions = new ArrayList<>(); - - // Connect clients - for (int i = 0; i < clientCount; i++) { - List> clientEvents = new ArrayList<>(); - allReceivedEvents.add(clientEvents); - - Flux> eventStream = webTestClient.get() - .uri("/sse") - .accept(MediaType.TEXT_EVENT_STREAM) - .exchange() - .expectStatus() - .isOk() - .returnResult(String.class) - .getResponseBody() - .map(data -> ServerSentEvent.builder().data(data).build()); - - Disposable subscription = eventStream.doOnNext(event -> { - clientEvents.add(event); - if (event.event() != null && event.event().equals(SseServerTransport.ENDPOINT_EVENT_TYPE)) { - connectLatch.countDown(); - } - else if (event.event() != null && event.event().equals(SseServerTransport.MESSAGE_EVENT_TYPE)) { - messageLatch.countDown(); - } - }).subscribe(); - - subscriptions.add(subscription); - } - - // Wait for all clients to connect - try { - assertThat(connectLatch.await(5, TimeUnit.SECONDS)).isTrue(); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - - // Verify initial connections - for (List> events : allReceivedEvents) { - assertThat(events).hasSize(1); - assertThat(events.get(0).data()).isEqualTo(messageEndpoint); - } - - // Give clients time to fully establish their subscriptions - logger.debug("Waiting for subscriptions to stabilize..."); - try { - Thread.sleep(1000); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - logger.debug("Sending broadcast message to {} clients", clientCount); - - // Send broadcast message - JSONRPCRequest broadcastMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "broadcast", "broadcast-id", - Map.of("message", "Hello all!")); - - // Send the message - transport.sendMessage(broadcastMessage).block(Duration.ofSeconds(5)); - - // Wait for all clients to receive the broadcast - try { - assertThat(messageLatch.await(5, TimeUnit.SECONDS)).isTrue(); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - - // Verify each client received both messages - for (List> events : allReceivedEvents) { - assertThat(events).hasSize(2); - assertThat(events.get(0).data()).isEqualTo(messageEndpoint); - assertThat(events.get(1).data()).contains("broadcast-id"); - } - - // Cleanup - subscriptions.forEach(Disposable::dispose); - } - - @Test - void testGracefulShutdown() { - // Connect a client - Flux> eventStream = webTestClient.get() - .uri("/sse") - .accept(MediaType.TEXT_EVENT_STREAM) - .exchange() - .expectStatus() - .isOk() - .returnResult(String.class) - .getResponseBody() - .map(data -> ServerSentEvent.builder().data(data).build()); - - List> receivedEvents = new ArrayList<>(); - eventStream.subscribe(receivedEvents::add); - - // Wait for connection - StepVerifier.create(Mono.delay(Duration.ofMillis(500))).expectNextCount(1).verifyComplete(); - - // Initiate shutdown - transport.closeGracefully().block(Duration.ofSeconds(5)); - - // Verify server rejects new connections with timeout - webTestClient.get() - .uri("/sse") - .accept(MediaType.TEXT_EVENT_STREAM) - .exchange() - .expectStatus() - .isEqualTo(HttpStatus.SERVICE_UNAVAILABLE) - .expectBody(String.class) - .isEqualTo("Server is shutting down"); - - // Verify server rejects new messages with timeout - webTestClient.post() - .uri(messageEndpoint) - .contentType(MediaType.APPLICATION_JSON) - .bodyValue(""" - { - "jsonrpc": "2.0", - "method": "test", - "id": "1", - "params": {} - } - """) - .exchange() - .expectStatus() - .isEqualTo(HttpStatus.SERVICE_UNAVAILABLE) - .expectBody(String.class) - .isEqualTo("Server is shutting down"); - } - - @Test - void testInvalidMessageHandling() { - // Test invalid JSON - webTestClient.post() - .uri(messageEndpoint) - .contentType(MediaType.APPLICATION_JSON) - .bodyValue("invalid json") - .exchange() - .expectStatus() - .isBadRequest(); - - // Test invalid message format - webTestClient.post().uri(messageEndpoint).contentType(MediaType.APPLICATION_JSON).bodyValue(""" - { - "invalid": "message" - } - """).exchange().expectStatus().isBadRequest(); - } - -} diff --git a/mcp/README.md b/mcp/README.md index 0fcffcb66..7a9ff8516 100644 --- a/mcp/README.md +++ b/mcp/README.md @@ -1,5 +1,5 @@ # Java MCP SDK Java SDK implementation of the Model Context Protocol, enabling seamless integration with language models and AI tools. +For comprehensive guides and API documentation, visit the [MCP Java SDK Reference Documentation](https://modelcontextprotocol.io/sdk/java/mcp-overview). -Find more at [Java MCP SDK](https://docs.spring.io/spring-ai-mcp/reference/mcp.html) \ No newline at end of file diff --git a/mcp/pom.xml b/mcp/pom.xml index 3febfbfb4..927be2eb7 100644 --- a/mcp/pom.xml +++ b/mcp/pom.xml @@ -4,22 +4,66 @@ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd"> 4.0.0 - org.springframework.experimental + io.modelcontextprotocol.sdk mcp-parent - 0.6.0 + 0.8.0 mcp jar Java MCP SDK Java SDK implementation of the Model Context Protocol, enabling seamless integration with language models and AI tools - https://github.com/spring-projects-experimental/spring-ai-mcp + https://github.com/modelcontextprotocol/java-sdk - https://github.com/spring-projects-experimental/spring-ai-mcp - git://github.com/spring-projects-experimental/spring-ai-mcp.git - git@github.com:spring-projects-experimental/spring-ai-mcp.git + https://github.com/modelcontextprotocol/java-sdk + git://github.com/modelcontextprotocol/java-sdk.git + git@github.com/modelcontextprotocol/java-sdk.git + + + + biz.aQute.bnd + bnd-maven-plugin + ${bnd-maven-plugin.version} + + + bnd-process + + bnd-process + + + + + + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + ${project.build.outputDirectory}/META-INF/MANIFEST.MF + + + + + + @@ -114,12 +158,20 @@ test + + net.javacrumbs.json-unit + json-unit-assertj + ${json-unit-assertj.version} + test + + + jakarta.servlet jakarta.servlet-api ${jakarta.servlet.version} - provided + provided @@ -139,4 +191,4 @@ - + \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java new file mode 100644 index 000000000..9cbef0500 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -0,0 +1,797 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ +package io.modelcontextprotocol.client; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpClientSession.NotificationHandler; +import io.modelcontextprotocol.spec.McpClientSession.RequestHandler; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.ListPromptsResult; +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpSchema.PaginatedRequest; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; + +/** + * The Model Context Protocol (MCP) client implementation that provides asynchronous + * communication with MCP servers using Project Reactor's Mono and Flux types. + * + *

    + * This client implements the MCP specification, enabling AI models to interact with + * external tools and resources through a standardized interface. Key features include: + *

      + *
    • Asynchronous communication using reactive programming patterns + *
    • Tool discovery and invocation for server-provided functionality + *
    • Resource access and management with URI-based addressing + *
    • Prompt template handling for standardized AI interactions + *
    • Real-time notifications for tools, resources, and prompts changes + *
    • Structured logging with configurable severity levels + *
    • Message sampling for AI model interactions + *
    + * + *

    + * The client follows a lifecycle: + *

      + *
    1. Initialization - Establishes connection and negotiates capabilities + *
    2. Normal Operation - Handles requests and notifications + *
    3. Graceful Shutdown - Ensures clean connection termination + *
    + * + *

    + * This implementation uses Project Reactor for non-blocking operations, making it + * suitable for high-throughput scenarios and reactive applications. All operations return + * Mono or Flux types that can be composed into reactive pipelines. + * + * @author Dariusz Jędrzejczyk + * @author Christian Tzolov + * @see McpClient + * @see McpSchema + * @see McpClientSession + */ +public class McpAsyncClient { + + private static final Logger logger = LoggerFactory.getLogger(McpAsyncClient.class); + + private static TypeReference VOID_TYPE_REFERENCE = new TypeReference<>() { + }; + + protected final Sinks.One initializedSink = Sinks.one(); + + private AtomicBoolean initialized = new AtomicBoolean(false); + + /** + * The max timeout to await for the client-server connection to be initialized. + */ + private final Duration initializationTimeout; + + /** + * The MCP session implementation that manages bidirectional JSON-RPC communication + * between clients and servers. + */ + private final McpClientSession mcpSession; + + /** + * Client capabilities. + */ + private final McpSchema.ClientCapabilities clientCapabilities; + + /** + * Client implementation information. + */ + private final McpSchema.Implementation clientInfo; + + /** + * Server capabilities. + */ + private McpSchema.ServerCapabilities serverCapabilities; + + /** + * Server implementation information. + */ + private McpSchema.Implementation serverInfo; + + /** + * Roots define the boundaries of where servers can operate within the filesystem, + * allowing them to understand which directories and files they have access to. + * Servers can request the list of roots from supporting clients and receive + * notifications when that list changes. + */ + private final ConcurrentHashMap roots; + + /** + * MCP provides a standardized way for servers to request LLM sampling ("completions" + * or "generations") from language models via clients. This flow allows clients to + * maintain control over model access, selection, and permissions while enabling + * servers to leverage AI capabilities—with no server API keys necessary. Servers can + * request text or image-based interactions and optionally include context from MCP + * servers in their prompts. + */ + private Function> samplingHandler; + + /** + * Client transport implementation. + */ + private final McpTransport transport; + + /** + * Supported protocol versions. + */ + private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); + + /** + * Create a new McpAsyncClient with the given transport and session request-response + * timeout. + * @param transport the transport to use. + * @param requestTimeout the session request-response timeout. + * @param initializationTimeout the max timeout to await for the client-server + * @param features the MCP Client supported features. + */ + McpAsyncClient(ClientMcpTransport transport, Duration requestTimeout, Duration initializationTimeout, + McpClientFeatures.Async features) { + + Assert.notNull(transport, "Transport must not be null"); + Assert.notNull(requestTimeout, "Request timeout must not be null"); + Assert.notNull(initializationTimeout, "Initialization timeout must not be null"); + + this.clientInfo = features.clientInfo(); + this.clientCapabilities = features.clientCapabilities(); + this.transport = transport; + this.roots = new ConcurrentHashMap<>(features.roots()); + this.initializationTimeout = initializationTimeout; + + // Request Handlers + Map> requestHandlers = new HashMap<>(); + + // Roots List Request Handler + if (this.clientCapabilities.roots() != null) { + requestHandlers.put(McpSchema.METHOD_ROOTS_LIST, rootsListRequestHandler()); + } + + // Sampling Handler + if (this.clientCapabilities.sampling() != null) { + if (features.samplingHandler() == null) { + throw new McpError("Sampling handler must not be null when client capabilities include sampling"); + } + this.samplingHandler = features.samplingHandler(); + requestHandlers.put(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, samplingCreateMessageHandler()); + } + + // Notification Handlers + Map notificationHandlers = new HashMap<>(); + + // Tools Change Notification + List, Mono>> toolsChangeConsumersFinal = new ArrayList<>(); + toolsChangeConsumersFinal + .add((notification) -> Mono.fromRunnable(() -> logger.debug("Tools changed: {}", notification))); + + if (!Utils.isEmpty(features.toolsChangeConsumers())) { + toolsChangeConsumersFinal.addAll(features.toolsChangeConsumers()); + } + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, + asyncToolsChangeNotificationHandler(toolsChangeConsumersFinal)); + + // Resources Change Notification + List, Mono>> resourcesChangeConsumersFinal = new ArrayList<>(); + resourcesChangeConsumersFinal + .add((notification) -> Mono.fromRunnable(() -> logger.debug("Resources changed: {}", notification))); + + if (!Utils.isEmpty(features.resourcesChangeConsumers())) { + resourcesChangeConsumersFinal.addAll(features.resourcesChangeConsumers()); + } + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, + asyncResourcesChangeNotificationHandler(resourcesChangeConsumersFinal)); + + // Prompts Change Notification + List, Mono>> promptsChangeConsumersFinal = new ArrayList<>(); + promptsChangeConsumersFinal + .add((notification) -> Mono.fromRunnable(() -> logger.debug("Prompts changed: {}", notification))); + if (!Utils.isEmpty(features.promptsChangeConsumers())) { + promptsChangeConsumersFinal.addAll(features.promptsChangeConsumers()); + } + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, + asyncPromptsChangeNotificationHandler(promptsChangeConsumersFinal)); + + // Utility Logging Notification + List>> loggingConsumersFinal = new ArrayList<>(); + loggingConsumersFinal.add((notification) -> Mono.fromRunnable(() -> logger.debug("Logging: {}", notification))); + if (!Utils.isEmpty(features.loggingConsumers())) { + loggingConsumersFinal.addAll(features.loggingConsumers()); + } + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE, + asyncLoggingNotificationHandler(loggingConsumersFinal)); + + this.mcpSession = new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers); + + } + + /** + * Get the server capabilities that define the supported features and functionality. + * @return The server capabilities + */ + public McpSchema.ServerCapabilities getServerCapabilities() { + return this.serverCapabilities; + } + + /** + * Get the server implementation information. + * @return The server implementation details + */ + public McpSchema.Implementation getServerInfo() { + return this.serverInfo; + } + + /** + * Check if the client-server connection is initialized. + * @return true if the client-server connection is initialized + */ + public boolean isInitialized() { + return this.initialized.get(); + } + + /** + * Get the client capabilities that define the supported features and functionality. + * @return The client capabilities + */ + public ClientCapabilities getClientCapabilities() { + return this.clientCapabilities; + } + + /** + * Get the client implementation information. + * @return The client implementation details + */ + public McpSchema.Implementation getClientInfo() { + return this.clientInfo; + } + + /** + * Closes the client connection immediately. + */ + public void close() { + this.mcpSession.close(); + } + + /** + * Gracefully closes the client connection. + * @return A Mono that completes when the connection is closed + */ + public Mono closeGracefully() { + return this.mcpSession.closeGracefully(); + } + + // -------------------------- + // Initialization + // -------------------------- + /** + * The initialization phase MUST be the first interaction between client and server. + * During this phase, the client and server: + *

      + *
    • Establish protocol version compatibility
    • + *
    • Exchange and negotiate capabilities
    • + *
    • Share implementation details
    • + *
    + *
    + * The client MUST initiate this phase by sending an initialize request containing: + * The protocol version the client supports, client's capabilities and clients + * implementation information. + *

    + * The server MUST respond with its own capabilities and information. + *

    + * After successful initialization, the client MUST send an initialized notification + * to indicate it is ready to begin normal operations. + * @return the initialize result. + * @see MCP + * Initialization Spec + */ + public Mono initialize() { + + String latestVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); + + McpSchema.InitializeRequest initializeRequest = new McpSchema.InitializeRequest(// @formatter:off + latestVersion, + this.clientCapabilities, + this.clientInfo); // @formatter:on + + Mono result = this.mcpSession.sendRequest(McpSchema.METHOD_INITIALIZE, + initializeRequest, new TypeReference() { + }); + + return result.flatMap(initializeResult -> { + + this.serverCapabilities = initializeResult.capabilities(); + this.serverInfo = initializeResult.serverInfo(); + + logger.info("Server response with Protocol: {}, Capabilities: {}, Info: {} and Instructions {}", + initializeResult.protocolVersion(), initializeResult.capabilities(), initializeResult.serverInfo(), + initializeResult.instructions()); + + if (!this.protocolVersions.contains(initializeResult.protocolVersion())) { + return Mono.error(new McpError( + "Unsupported protocol version from the server: " + initializeResult.protocolVersion())); + } + + return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null).doOnSuccess(v -> { + this.initialized.set(true); + this.initializedSink.tryEmitValue(initializeResult); + }).thenReturn(initializeResult); + }); + } + + /** + * Utility method to handle the common pattern of checking initialization before + * executing an operation. + * @param The type of the result Mono + * @param actionName The action to perform if the client is initialized + * @param operation The operation to execute if the client is initialized + * @return A Mono that completes with the result of the operation + */ + private Mono withInitializationCheck(String actionName, + Function> operation) { + return this.initializedSink.asMono() + .timeout(this.initializationTimeout) + .onErrorResume(TimeoutException.class, + ex -> Mono.error(new McpError("Client must be initialized before " + actionName))) + .flatMap(operation); + } + + // -------------------------- + // Basic Utilites + // -------------------------- + + /** + * Sends a ping request to the server. + * @return A Mono that completes with the server's ping response + */ + public Mono ping() { + return this.withInitializationCheck("pinging the server", initializedResult -> this.mcpSession + .sendRequest(McpSchema.METHOD_PING, null, new TypeReference() { + })); + } + + // -------------------------- + // Roots + // -------------------------- + /** + * Adds a new root to the client's root list. + * @param root The root to add. + * @return A Mono that completes when the root is added and notifications are sent. + */ + public Mono addRoot(Root root) { + + if (root == null) { + return Mono.error(new McpError("Root must not be null")); + } + + if (this.clientCapabilities.roots() == null) { + return Mono.error(new McpError("Client must be configured with roots capabilities")); + } + + if (this.roots.containsKey(root.uri())) { + return Mono.error(new McpError("Root with uri '" + root.uri() + "' already exists")); + } + + this.roots.put(root.uri(), root); + + logger.debug("Added root: {}", root); + + if (this.clientCapabilities.roots().listChanged()) { + if (this.isInitialized()) { + return this.rootsListChangedNotification(); + } + else { + logger.warn("Client is not initialized, ignore sending a roots list changed notification"); + } + } + return Mono.empty(); + } + + /** + * Removes a root from the client's root list. + * @param rootUri The URI of the root to remove. + * @return A Mono that completes when the root is removed and notifications are sent. + */ + public Mono removeRoot(String rootUri) { + + if (rootUri == null) { + return Mono.error(new McpError("Root uri must not be null")); + } + + if (this.clientCapabilities.roots() == null) { + return Mono.error(new McpError("Client must be configured with roots capabilities")); + } + + Root removed = this.roots.remove(rootUri); + + if (removed != null) { + logger.debug("Removed Root: {}", rootUri); + if (this.clientCapabilities.roots().listChanged()) { + if (this.isInitialized()) { + return this.rootsListChangedNotification(); + } + else { + logger.warn("Client is not initialized, ignore sending a roots list changed notification"); + } + + } + return Mono.empty(); + } + return Mono.error(new McpError("Root with uri '" + rootUri + "' not found")); + } + + /** + * Manually sends a roots/list_changed notification. The addRoot and removeRoot + * methods automatically send the roots/list_changed notification if the client is in + * an initialized state. + * @return A Mono that completes when the notification is sent. + */ + public Mono rootsListChangedNotification() { + return this.withInitializationCheck("sending roots list changed notification", + initResult -> this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED)); + } + + private RequestHandler rootsListRequestHandler() { + return params -> { + @SuppressWarnings("unused") + McpSchema.PaginatedRequest request = transport.unmarshalFrom(params, + new TypeReference() { + }); + + List roots = this.roots.values().stream().toList(); + + return Mono.just(new McpSchema.ListRootsResult(roots)); + }; + } + + // -------------------------- + // Sampling + // -------------------------- + private RequestHandler samplingCreateMessageHandler() { + return params -> { + McpSchema.CreateMessageRequest request = transport.unmarshalFrom(params, + new TypeReference() { + }); + + return this.samplingHandler.apply(request); + }; + } + + // -------------------------- + // Tools + // -------------------------- + private static final TypeReference CALL_TOOL_RESULT_TYPE_REF = new TypeReference<>() { + }; + + private static final TypeReference LIST_TOOLS_RESULT_TYPE_REF = new TypeReference<>() { + }; + + /** + * Calls a tool provided by the server. Tools enable servers to expose executable + * functionality that can interact with external systems, perform computations, and + * take actions in the real world. + * @param callToolRequest The request containing the tool name and input parameters. + * @return A Mono that emits the result of the tool call, including the output and any + * errors. + * @see McpSchema.CallToolRequest + * @see McpSchema.CallToolResult + * @see #listTools() + */ + public Mono callTool(McpSchema.CallToolRequest callToolRequest) { + return this.withInitializationCheck("calling tools", initializedResult -> { + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server does not provide tools capability")); + } + return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF); + }); + } + + /** + * Retrieves the list of all tools provided by the server. + * @return A Mono that emits the list of tools result. + */ + public Mono listTools() { + return this.listTools(null); + } + + /** + * Retrieves a paginated list of tools provided by the server. + * @param cursor Optional pagination cursor from a previous list request + * @return A Mono that emits the list of tools result + */ + public Mono listTools(String cursor) { + return this.withInitializationCheck("listing tools", initializedResult -> { + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server does not provide tools capability")); + } + return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor), + LIST_TOOLS_RESULT_TYPE_REF); + }); + } + + private NotificationHandler asyncToolsChangeNotificationHandler( + List, Mono>> toolsChangeConsumers) { + // TODO: params are not used yet + return params -> this.listTools() + .flatMap(listToolsResult -> Flux.fromIterable(toolsChangeConsumers) + .flatMap(consumer -> consumer.apply(listToolsResult.tools())) + .onErrorResume(error -> { + logger.error("Error handling tools list change notification", error); + return Mono.empty(); + }) + .then()); + } + + // -------------------------- + // Resources + // -------------------------- + + private static final TypeReference LIST_RESOURCES_RESULT_TYPE_REF = new TypeReference<>() { + }; + + private static final TypeReference READ_RESOURCE_RESULT_TYPE_REF = new TypeReference<>() { + }; + + private static final TypeReference LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF = new TypeReference<>() { + }; + + /** + * Retrieves the list of all resources provided by the server. Resources represent any + * kind of UTF-8 encoded data that an MCP server makes available to clients, such as + * database records, API responses, log files, and more. + * @return A Mono that completes with the list of resources result. + * @see McpSchema.ListResourcesResult + * @see #readResource(McpSchema.Resource) + */ + public Mono listResources() { + return this.listResources(null); + } + + /** + * Retrieves a paginated list of resources provided by the server. Resources represent + * any kind of UTF-8 encoded data that an MCP server makes available to clients, such + * as database records, API responses, log files, and more. + * @param cursor Optional pagination cursor from a previous list request. + * @return A Mono that completes with the list of resources result. + * @see McpSchema.ListResourcesResult + * @see #readResource(McpSchema.Resource) + */ + public Mono listResources(String cursor) { + return this.withInitializationCheck("listing resources", initializedResult -> { + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server does not provide the resources capability")); + } + return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_LIST, new McpSchema.PaginatedRequest(cursor), + LIST_RESOURCES_RESULT_TYPE_REF); + }); + } + + /** + * Reads the content of a specific resource identified by the provided Resource + * object. This method fetches the actual data that the resource represents. + * @param resource The resource to read, containing the URI that identifies the + * resource. + * @return A Mono that completes with the resource content. + * @see McpSchema.Resource + * @see McpSchema.ReadResourceResult + */ + public Mono readResource(McpSchema.Resource resource) { + return this.readResource(new McpSchema.ReadResourceRequest(resource.uri())); + } + + /** + * Reads the content of a specific resource identified by the provided request. This + * method fetches the actual data that the resource represents. + * @param readResourceRequest The request containing the URI of the resource to read + * @return A Mono that completes with the resource content. + * @see McpSchema.ReadResourceRequest + * @see McpSchema.ReadResourceResult + */ + public Mono readResource(McpSchema.ReadResourceRequest readResourceRequest) { + return this.withInitializationCheck("reading resources", initializedResult -> { + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server does not provide the resources capability")); + } + return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_READ, readResourceRequest, + READ_RESOURCE_RESULT_TYPE_REF); + }); + } + + /** + * Retrieves the list of all resource templates provided by the server. Resource + * templates allow servers to expose parameterized resources using URI templates, + * enabling dynamic resource access based on variable parameters. + * @return A Mono that completes with the list of resource templates result. + * @see McpSchema.ListResourceTemplatesResult + */ + public Mono listResourceTemplates() { + return this.listResourceTemplates(null); + } + + /** + * Retrieves a paginated list of resource templates provided by the server. Resource + * templates allow servers to expose parameterized resources using URI templates, + * enabling dynamic resource access based on variable parameters. + * @param cursor Optional pagination cursor from a previous list request. + * @return A Mono that completes with the list of resource templates result. + * @see McpSchema.ListResourceTemplatesResult + */ + public Mono listResourceTemplates(String cursor) { + return this.withInitializationCheck("listing resource templates", initializedResult -> { + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server does not provide the resources capability")); + } + return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, + new McpSchema.PaginatedRequest(cursor), LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF); + }); + } + + /** + * Subscribes to changes in a specific resource. When the resource changes on the + * server, the client will receive notifications through the resources change + * notification handler. + * @param subscribeRequest The subscribe request containing the URI of the resource. + * @return A Mono that completes when the subscription is complete. + * @see McpSchema.SubscribeRequest + * @see #unsubscribeResource(McpSchema.UnsubscribeRequest) + */ + public Mono subscribeResource(McpSchema.SubscribeRequest subscribeRequest) { + return this.withInitializationCheck("subscribing to resources", initializedResult -> this.mcpSession + .sendRequest(McpSchema.METHOD_RESOURCES_SUBSCRIBE, subscribeRequest, VOID_TYPE_REFERENCE)); + } + + /** + * Cancels an existing subscription to a resource. After unsubscribing, the client + * will no longer receive notifications when the resource changes. + * @param unsubscribeRequest The unsubscribe request containing the URI of the + * resource. + * @return A Mono that completes when the unsubscription is complete. + * @see McpSchema.UnsubscribeRequest + * @see #subscribeResource(McpSchema.SubscribeRequest) + */ + public Mono unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) { + return this.withInitializationCheck("unsubscribing from resources", initializedResult -> this.mcpSession + .sendRequest(McpSchema.METHOD_RESOURCES_UNSUBSCRIBE, unsubscribeRequest, VOID_TYPE_REFERENCE)); + } + + private NotificationHandler asyncResourcesChangeNotificationHandler( + List, Mono>> resourcesChangeConsumers) { + return params -> listResources().flatMap(listResourcesResult -> Flux.fromIterable(resourcesChangeConsumers) + .flatMap(consumer -> consumer.apply(listResourcesResult.resources())) + .onErrorResume(error -> { + logger.error("Error handling resources list change notification", error); + return Mono.empty(); + }) + .then()); + } + + // -------------------------- + // Prompts + // -------------------------- + private static final TypeReference LIST_PROMPTS_RESULT_TYPE_REF = new TypeReference<>() { + }; + + private static final TypeReference GET_PROMPT_RESULT_TYPE_REF = new TypeReference<>() { + }; + + /** + * Retrieves the list of all prompts provided by the server. + * @return A Mono that completes with the list of prompts result. + * @see McpSchema.ListPromptsResult + * @see #getPrompt(GetPromptRequest) + */ + public Mono listPrompts() { + return this.listPrompts(null); + } + + /** + * Retrieves a paginated list of prompts provided by the server. + * @param cursor Optional pagination cursor from a previous list request + * @return A Mono that completes with the list of prompts result. + * @see McpSchema.ListPromptsResult + * @see #getPrompt(GetPromptRequest) + */ + public Mono listPrompts(String cursor) { + return this.withInitializationCheck("listing prompts", initializedResult -> this.mcpSession + .sendRequest(McpSchema.METHOD_PROMPT_LIST, new PaginatedRequest(cursor), LIST_PROMPTS_RESULT_TYPE_REF)); + } + + /** + * Retrieves a specific prompt by its ID. This provides the complete prompt template + * including all parameters and instructions for generating AI content. + * @param getPromptRequest The request containing the ID of the prompt to retrieve. + * @return A Mono that completes with the prompt result. + * @see McpSchema.GetPromptRequest + * @see McpSchema.GetPromptResult + * @see #listPrompts() + */ + public Mono getPrompt(GetPromptRequest getPromptRequest) { + return this.withInitializationCheck("getting prompts", initializedResult -> this.mcpSession + .sendRequest(McpSchema.METHOD_PROMPT_GET, getPromptRequest, GET_PROMPT_RESULT_TYPE_REF)); + } + + private NotificationHandler asyncPromptsChangeNotificationHandler( + List, Mono>> promptsChangeConsumers) { + return params -> listPrompts().flatMap(listPromptsResult -> Flux.fromIterable(promptsChangeConsumers) + .flatMap(consumer -> consumer.apply(listPromptsResult.prompts())) + .onErrorResume(error -> { + logger.error("Error handling prompts list change notification", error); + return Mono.empty(); + }) + .then()); + } + + // -------------------------- + // Logging + // -------------------------- + private NotificationHandler asyncLoggingNotificationHandler( + List>> loggingConsumers) { + + return params -> { + McpSchema.LoggingMessageNotification loggingMessageNotification = transport.unmarshalFrom(params, + new TypeReference() { + }); + + return Flux.fromIterable(loggingConsumers) + .flatMap(consumer -> consumer.apply(loggingMessageNotification)) + .then(); + }; + } + + /** + * Sets the minimum logging level for messages received from the server. The client + * will only receive log messages at or above the specified severity level. + * @param loggingLevel The minimum logging level to receive. + * @return A Mono that completes when the logging level is set. + * @see McpSchema.LoggingLevel + */ + public Mono setLoggingLevel(LoggingLevel loggingLevel) { + if (loggingLevel == null) { + return Mono.error(new McpError("Logging level must not be null")); + } + + return this.withInitializationCheck("setting logging level", initializedResult -> { + String levelName = this.transport.unmarshalFrom(loggingLevel, new TypeReference() { + }); + Map params = Map.of("level", levelName); + return this.mcpSession.sendNotification(McpSchema.METHOD_LOGGING_SET_LEVEL, params); + }); + } + + /** + * This method is package-private and used for test only. Should not be called by user + * code. + * @param protocolVersions the Client supported protocol versions. + */ + void setProtocolVersions(List protocolVersions) { + this.protocolVersions = protocolVersions; + } + +} diff --git a/mcp/src/main/java/org/springframework/ai/mcp/client/McpClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java similarity index 67% rename from mcp/src/main/java/org/springframework/ai/mcp/client/McpClient.java rename to mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java index e11a9a713..9c5f7b015 100644 --- a/mcp/src/main/java/org/springframework/ai/mcp/client/McpClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -1,20 +1,8 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.client; +package io.modelcontextprotocol.client; import java.time.Duration; import java.util.ArrayList; @@ -24,18 +12,18 @@ import java.util.function.Consumer; import java.util.function.Function; +import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.Implementation; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; -import org.springframework.ai.mcp.spec.ClientMcpTransport; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.ClientCapabilities; -import org.springframework.ai.mcp.spec.McpSchema.CreateMessageRequest; -import org.springframework.ai.mcp.spec.McpSchema.CreateMessageResult; -import org.springframework.ai.mcp.spec.McpSchema.Implementation; -import org.springframework.ai.mcp.spec.McpSchema.Root; -import org.springframework.ai.mcp.spec.McpTransport; -import org.springframework.ai.mcp.util.Assert; - /** * Factory class for creating Model Context Protocol (MCP) clients. MCP is a protocol that * enables AI models to interact with external tools and resources through a standardized @@ -126,11 +114,31 @@ public interface McpClient { * and {@code SseClientTransport} for SSE-based communication. * @return A new builder instance for configuring the client * @throws IllegalArgumentException if transport is null + * @deprecated This method will be removed in 0.9.0. Use + * {@link #sync(McpClientTransport)} */ + @Deprecated static SyncSpec sync(ClientMcpTransport transport) { return new SyncSpec(transport); } + /** + * Start building a synchronous MCP client with the specified transport layer. The + * synchronous MCP client provides blocking operations. Synchronous clients wait for + * each operation to complete before returning, making them simpler to use but + * potentially less performant for concurrent operations. The transport layer handles + * the low-level communication between client and server using protocols like stdio or + * Server-Sent Events (SSE). + * @param transport The transport layer implementation for MCP communication. Common + * implementations include {@code StdioClientTransport} for stdio-based communication + * and {@code SseClientTransport} for SSE-based communication. + * @return A new builder instance for configuring the client + * @throws IllegalArgumentException if transport is null + */ + static SyncSpec sync(McpClientTransport transport) { + return new SyncSpec(transport); + } + /** * Start building an asynchronous MCP client with the specified transport layer. The * asynchronous MCP client provides non-blocking operations. Asynchronous clients @@ -143,265 +151,29 @@ static SyncSpec sync(ClientMcpTransport transport) { * and {@code SseClientTransport} for SSE-based communication. * @return A new builder instance for configuring the client * @throws IllegalArgumentException if transport is null + * @deprecated This method will be removed in 0.9.0. Use + * {@link #async(McpClientTransport)} */ + @Deprecated static AsyncSpec async(ClientMcpTransport transport) { return new AsyncSpec(transport); } /** - * Start building an MCP client with the specified transport layer. The transport - * layer handles the low-level communication between client and server using protocols - * like stdio or Server-Sent Events (SSE). + * Start building an asynchronous MCP client with the specified transport layer. The + * asynchronous MCP client provides non-blocking operations. Asynchronous clients + * return reactive primitives (Mono/Flux) immediately, allowing for concurrent + * operations and reactive programming patterns. The transport layer handles the + * low-level communication between client and server using protocols like stdio or + * Server-Sent Events (SSE). * @param transport The transport layer implementation for MCP communication. Common * implementations include {@code StdioClientTransport} for stdio-based communication * and {@code SseClientTransport} for SSE-based communication. * @return A new builder instance for configuring the client * @throws IllegalArgumentException if transport is null - * @deprecated Use {@link #sync(ClientMcpTransport)} or - * {@link #async(ClientMcpTransport)} specification builder to configure the client - * and build an instance. */ - @Deprecated - public static Builder using(ClientMcpTransport transport) { - return new Builder(transport); - } - - /** - * Builder class for creating and configuring MCP clients. This class follows the - * builder pattern to provide a fluent API for setting up clients with custom - * configurations. - * - *

    - * The builder supports configuration of: - *

      - *
    • Transport layer for client-server communication - *
    • Request timeouts for operation boundaries - *
    • Client capabilities for feature negotiation - *
    • Client implementation details for version tracking - *
    • Root URIs for resource access - *
    • Change notification handlers for tools, resources, and prompts - *
    • Custom message sampling logic - *
    - * - * @deprecated Use {@link #sync(ClientMcpTransport)} or - * {@link #async(ClientMcpTransport)} specification builder to instantiate an - * instance. - */ - @Deprecated - public static class Builder { - - private final ClientMcpTransport transport; - - private Duration requestTimeout = Duration.ofSeconds(20); // Default timeout - - private ClientCapabilities capabilities; - - private Implementation clientInfo = new Implementation("Spring AI MCP Client", "0.3.1"); - - private Map roots = new HashMap<>(); - - private List>> toolsChangeConsumers = new ArrayList<>(); - - private List>> resourcesChangeConsumers = new ArrayList<>(); - - private List>> promptsChangeConsumers = new ArrayList<>(); - - private List> loggingConsumers = new ArrayList<>(); - - private Function samplingHandler; - - private Builder(ClientMcpTransport transport) { - Assert.notNull(transport, "Transport must not be null"); - this.transport = transport; - } - - /** - * Sets the duration to wait for server responses before timing out requests. This - * timeout applies to all requests made through the client, including tool calls, - * resource access, and prompt operations. - * @param requestTimeout The duration to wait before timing out requests. Must not - * be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if requestTimeout is null - */ - public Builder requestTimeout(Duration requestTimeout) { - Assert.notNull(requestTimeout, "Request timeout must not be null"); - this.requestTimeout = requestTimeout; - return this; - } - - /** - * Sets the client capabilities that will be advertised to the server during - * connection initialization. Capabilities define what features the client - * supports, such as tool execution, resource access, and prompt handling. - * @param capabilities The client capabilities configuration. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if capabilities is null - */ - public Builder capabilities(ClientCapabilities capabilities) { - Assert.notNull(capabilities, "Capabilities must not be null"); - this.capabilities = capabilities; - return this; - } - - /** - * Sets the client implementation information that will be shared with the server - * during connection initialization. This helps with version compatibility and - * debugging. - * @param clientInfo The client implementation details including name and version. - * Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if clientInfo is null - */ - public Builder clientInfo(Implementation clientInfo) { - Assert.notNull(clientInfo, "Client info must not be null"); - this.clientInfo = clientInfo; - return this; - } - - /** - * Sets the root URIs that this client can access. Roots define the base URIs for - * resources that the client can request from the server. For example, a root - * might be "file://workspace" for accessing workspace files. - * @param roots A list of root definitions. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if roots is null - */ - public Builder roots(List roots) { - Assert.notNull(roots, "Roots must not be null"); - for (Root root : roots) { - this.roots.put(root.uri(), root); - } - return this; - } - - /** - * Sets the root URIs that this client can access, using a varargs parameter for - * convenience. This is an alternative to {@link #roots(List)}. - * @param roots An array of root definitions. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if roots is null - * @see #roots(List) - */ - public Builder roots(Root... roots) { - Assert.notNull(roots, "Roots must not be null"); - for (Root root : roots) { - this.roots.put(root.uri(), root); - } - return this; - } - - /** - * Sets a custom sampling handler for processing message creation requests. The - * sampling handler can modify or validate messages before they are sent to the - * server, enabling custom processing logic. - * @param samplingHandler A function that processes message requests and returns - * results. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if samplingHandler is null - */ - public Builder sampling(Function samplingHandler) { - Assert.notNull(samplingHandler, "Sampling handler must not be null"); - this.samplingHandler = samplingHandler; - return this; - } - - /** - * Adds a consumer to be notified when the available tools change. This allows the - * client to react to changes in the server's tool capabilities, such as tools - * being added or removed. - * @param toolsChangeConsumer A consumer that receives the updated list of - * available tools. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if toolsChangeConsumer is null - */ - public Builder toolsChangeConsumer(Consumer> toolsChangeConsumer) { - Assert.notNull(toolsChangeConsumer, "Tools change consumer must not be null"); - this.toolsChangeConsumers.add(toolsChangeConsumer); - return this; - } - - /** - * Adds a consumer to be notified when the available resources change. This allows - * the client to react to changes in the server's resource availability, such as - * files being added or removed. - * @param resourcesChangeConsumer A consumer that receives the updated list of - * available resources. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourcesChangeConsumer is null - */ - public Builder resourcesChangeConsumer(Consumer> resourcesChangeConsumer) { - Assert.notNull(resourcesChangeConsumer, "Resources change consumer must not be null"); - this.resourcesChangeConsumers.add(resourcesChangeConsumer); - return this; - } - - /** - * Adds a consumer to be notified when the available prompts change. This allows - * the client to react to changes in the server's prompt templates, such as new - * templates being added or existing ones being modified. - * @param promptsChangeConsumer A consumer that receives the updated list of - * available prompts. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if promptsChangeConsumer is null - */ - public Builder promptsChangeConsumer(Consumer> promptsChangeConsumer) { - Assert.notNull(promptsChangeConsumer, "Prompts change consumer must not be null"); - this.promptsChangeConsumers.add(promptsChangeConsumer); - return this; - } - - /** - * Adds a consumer to be notified when logging messages are received from the - * server. This allows the client to react to log messages, such as warnings or - * errors, that are sent by the server. - * @param loggingConsumer A consumer that receives logging messages. Must not be - * null. - * @return This builder instance for method chaining - */ - public Builder loggingConsumer(Consumer loggingConsumer) { - Assert.notNull(loggingConsumer, "Logging consumer must not be null"); - this.loggingConsumers.add(loggingConsumer); - return this; - } - - /** - * Adds multiple consumers to be notified when logging messages are received from - * the server. This allows the client to react to log messages, such as warnings - * or errors, that are sent by the server. - * @param loggingConsumers A list of consumers that receive logging messages. Must - * not be null. - * @return This builder instance for method chaining - */ - public Builder loggingConsumers(List> loggingConsumers) { - Assert.notNull(loggingConsumers, "Logging consumers must not be null"); - this.loggingConsumers.addAll(loggingConsumers); - return this; - } - - /** - * Builds a synchronous MCP client that provides blocking operations. Synchronous - * clients wait for each operation to complete before returning, making them - * simpler to use but potentially less performant for concurrent operations. - * @return A new instance of {@link McpSyncClient} configured with this builder's - * settings - */ - public McpSyncClient sync() { - return new McpSyncClient(async()); - } - - /** - * Builds an asynchronous MCP client that provides non-blocking operations. - * Asynchronous clients return CompletableFuture objects immediately, allowing for - * concurrent operations and reactive programming patterns. - * @return A new instance of {@link McpAsyncClient} configured with this builder's - * settings - */ - public McpAsyncClient async() { - return new McpAsyncClient(transport, requestTimeout, clientInfo, capabilities, roots, toolsChangeConsumers, - resourcesChangeConsumers, promptsChangeConsumers, loggingConsumers, samplingHandler); - } - + static AsyncSpec async(McpClientTransport transport) { + return new AsyncSpec(transport); } /** @@ -426,9 +198,11 @@ class SyncSpec { private Duration requestTimeout = Duration.ofSeconds(20); // Default timeout + private Duration initializationTimeout = Duration.ofSeconds(20); + private ClientCapabilities capabilities; - private Implementation clientInfo = new Implementation("Spring AI MCP Client", "1.0.0"); + private Implementation clientInfo = new Implementation("Java SDK MCP Client", "1.0.0"); private final Map roots = new HashMap<>(); @@ -462,6 +236,18 @@ public SyncSpec requestTimeout(Duration requestTimeout) { return this; } + /** + * @param initializationTimeout The duration to wait for the initializaiton + * lifecycle step to complete. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if initializationTimeout is null + */ + public SyncSpec initializationTimeout(Duration initializationTimeout) { + Assert.notNull(initializationTimeout, "Initialization timeout must not be null"); + this.initializationTimeout = initializationTimeout; + return this; + } + /** * Sets the client capabilities that will be advertised to the server during * connection initialization. Capabilities define what features the client @@ -623,7 +409,8 @@ public McpSyncClient build() { McpClientFeatures.Async asyncFeatures = McpClientFeatures.Async.fromSync(syncFeatures); - return new McpSyncClient(new McpAsyncClient(transport, this.requestTimeout, asyncFeatures)); + return new McpSyncClient( + new McpAsyncClient(transport, this.requestTimeout, this.initializationTimeout, asyncFeatures)); } } @@ -650,6 +437,8 @@ class AsyncSpec { private Duration requestTimeout = Duration.ofSeconds(20); // Default timeout + private Duration initializationTimeout = Duration.ofSeconds(20); + private ClientCapabilities capabilities; private Implementation clientInfo = new Implementation("Spring AI MCP Client", "0.3.1"); @@ -686,6 +475,18 @@ public AsyncSpec requestTimeout(Duration requestTimeout) { return this; } + /** + * @param initializationTimeout The duration to wait for the initializaiton + * lifecycle step to complete. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if initializationTimeout is null + */ + public AsyncSpec initializationTimeout(Duration initializationTimeout) { + Assert.notNull(initializationTimeout, "Initialization timeout must not be null"); + this.initializationTimeout = initializationTimeout; + return this; + } + /** * Sets the client capabilities that will be advertised to the server during * connection initialization. Capabilities define what features the client @@ -843,7 +644,7 @@ public AsyncSpec loggingConsumers( * @return a new instance of {@link McpAsyncClient}. */ public McpAsyncClient build() { - return new McpAsyncClient(this.transport, this.requestTimeout, + return new McpAsyncClient(this.transport, this.requestTimeout, this.initializationTimeout, new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.promptsChangeConsumers, this.loggingConsumers, this.samplingHandler)); diff --git a/mcp/src/main/java/org/springframework/ai/mcp/client/McpClientFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java similarity index 97% rename from mcp/src/main/java/org/springframework/ai/mcp/client/McpClientFeatures.java rename to mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java index 05be98de4..284b93f88 100644 --- a/mcp/src/main/java/org/springframework/ai/mcp/client/McpClientFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java @@ -1,4 +1,8 @@ -package org.springframework.ai.mcp.client; +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; import java.util.ArrayList; import java.util.HashMap; @@ -8,13 +12,12 @@ import java.util.function.Consumer; import java.util.function.Function; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.util.Assert; -import org.springframework.ai.mcp.util.Utils; - /** * Representation of features and capabilities for Model Context Protocol (MCP) clients. * This class provides two record types for managing client features: diff --git a/mcp/src/main/java/org/springframework/ai/mcp/client/McpSyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java similarity index 85% rename from mcp/src/main/java/org/springframework/ai/mcp/client/McpSyncClient.java rename to mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index 1d008ab94..ec0a0dfdb 100644 --- a/mcp/src/main/java/org/springframework/ai/mcp/client/McpSyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -1,34 +1,21 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.client; +package io.modelcontextprotocol.client; import java.time.Duration; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.ListPromptsResult; +import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.mcp.spec.ClientMcpTransport; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.ClientCapabilities; -import org.springframework.ai.mcp.spec.McpSchema.GetPromptRequest; -import org.springframework.ai.mcp.spec.McpSchema.GetPromptResult; -import org.springframework.ai.mcp.spec.McpSchema.ListPromptsResult; -import org.springframework.ai.mcp.util.Assert; - /** * A synchronous client implementation for the Model Context Protocol (MCP) that wraps an * {@link McpAsyncClient} to provide blocking operations. @@ -79,7 +66,8 @@ public class McpSyncClient implements AutoCloseable { * Create a new McpSyncClient with the given delegate. * @param delegate the asynchronous kernel on top of which this synchronous client * provides a blocking API. - * @deprecated Use {@link McpClient#sync(ClientMcpTransport)} to obtain an instance. + * @deprecated This method will be removed in 0.9.0. Use + * {@link McpClient#sync(McpClientTransport)} to obtain an instance. */ @Deprecated // TODO make the constructor package private post-deprecation @@ -297,14 +285,6 @@ public McpSchema.ListResourceTemplatesResult listResourceTemplates() { return this.delegate.listResourceTemplates().block(); } - /** - * List Changed Notification. When the list of available resources changes, servers - * that declared the listChanged capability SHOULD send a notification: - */ - public void sendResourcesListChanged() { - this.delegate.sendResourcesListChanged().block(); - } - /** * Subscriptions. The protocol supports optional subscriptions to resource changes. * Clients can subscribe to specific resources and receive notifications when they @@ -342,15 +322,6 @@ public GetPromptResult getPrompt(GetPromptRequest getPromptRequest) { return this.delegate.getPrompt(getPromptRequest).block(); } - /** - * (Server) An optional notification from the server to the client, informing it that - * the list of prompts it offers has changed. This may be issued by servers without - * any previous subscription from the client. - */ - public void promptListChangedNotification() { - this.delegate.promptListChangedNotification().block(); - } - /** * Client can set the minimum logging level it wants to receive from the server. * @param loggingLevel the min logging level diff --git a/mcp/src/main/java/org/springframework/ai/mcp/client/transport/FlowSseClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java similarity index 91% rename from mcp/src/main/java/org/springframework/ai/mcp/client/transport/FlowSseClient.java rename to mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java index 4b6191d9f..7fc679937 100644 --- a/mcp/src/main/java/org/springframework/ai/mcp/client/transport/FlowSseClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java @@ -1,19 +1,7 @@ /* * Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. */ -package org.springframework.ai.mcp.client.transport; +package io.modelcontextprotocol.client.transport; import java.net.URI; import java.net.http.HttpClient; diff --git a/mcp/src/main/java/org/springframework/ai/mcp/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java similarity index 86% rename from mcp/src/main/java/org/springframework/ai/mcp/client/transport/HttpClientSseClientTransport.java rename to mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 68f103286..ca1b0e87a 100644 --- a/mcp/src/main/java/org/springframework/ai/mcp/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -1,33 +1,7 @@ /* * Copyright 2024 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.client.transport; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.springframework.ai.mcp.client.transport.FlowSseClient.SseEvent; -import org.springframework.ai.mcp.spec.ClientMcpTransport; -import org.springframework.ai.mcp.spec.McpError; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.JSONRPCMessage; -import org.springframework.ai.mcp.util.Assert; - -import reactor.core.publisher.Mono; +package io.modelcontextprotocol.client.transport; import java.io.IOException; import java.net.URI; @@ -41,9 +15,21 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.transport.FlowSseClient.SseEvent; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + /** * Server-Sent Events (SSE) implementation of the - * {@link org.springframework.ai.mcp.spec.McpTransport} that follows the MCP HTTP with SSE + * {@link io.modelcontextprotocol.spec.McpTransport} that follows the MCP HTTP with SSE * transport specification, using Java's HttpClient. * *

    @@ -65,10 +51,10 @@ * * * @author Christian Tzolov - * @see org.springframework.ai.mcp.spec.McpTransport - * @see org.springframework.ai.mcp.spec.ClientMcpTransport + * @see io.modelcontextprotocol.spec.McpTransport + * @see io.modelcontextprotocol.spec.McpClientTransport */ -public class HttpClientSseClientTransport implements ClientMcpTransport { +public class HttpClientSseClientTransport implements McpClientTransport { private static final Logger logger = LoggerFactory.getLogger(HttpClientSseClientTransport.class); @@ -100,7 +86,7 @@ public class HttpClientSseClientTransport implements ClientMcpTransport { private volatile boolean isClosing = false; /** Latch for coordinating endpoint discovery */ - private CountDownLatch closeLatch = new CountDownLatch(1); + private final CountDownLatch closeLatch = new CountDownLatch(1); /** Holds the discovered message endpoint URL */ private final AtomicReference messageEndpoint = new AtomicReference<>(); @@ -231,7 +217,8 @@ public Mono sendMessage(JSONRPCMessage message) { return Mono.fromFuture( httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding()).thenAccept(response -> { - if (response.statusCode() != 200) { + if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202 + && response.statusCode() != 206) { logger.error("Error sending message: {}", response.statusCode()); } })); diff --git a/mcp/src/main/java/org/springframework/ai/mcp/client/transport/ServerParameters.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/ServerParameters.java similarity index 83% rename from mcp/src/main/java/org/springframework/ai/mcp/client/transport/ServerParameters.java rename to mcp/src/main/java/io/modelcontextprotocol/client/transport/ServerParameters.java index a5c685812..25a02279f 100644 --- a/mcp/src/main/java/org/springframework/ai/mcp/client/transport/ServerParameters.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/ServerParameters.java @@ -1,20 +1,8 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.client.transport; +package io.modelcontextprotocol.client.transport; import java.util.ArrayList; import java.util.Arrays; @@ -25,8 +13,7 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; - -import org.springframework.ai.mcp.util.Assert; +import io.modelcontextprotocol.util.Assert; /** * Server parameters for stdio client. diff --git a/mcp/src/main/java/org/springframework/ai/mcp/client/transport/StdioClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java similarity index 89% rename from mcp/src/main/java/org/springframework/ai/mcp/client/transport/StdioClientTransport.java rename to mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java index 74d001a5e..f9a97849f 100644 --- a/mcp/src/main/java/org/springframework/ai/mcp/client/transport/StdioClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java @@ -1,20 +1,8 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.client.transport; +package io.modelcontextprotocol.client.transport; import java.io.BufferedReader; import java.io.IOException; @@ -23,13 +11,16 @@ import java.time.Duration; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executors; import java.util.function.Consumer; import java.util.function.Function; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; @@ -38,11 +29,6 @@ import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; -import org.springframework.ai.mcp.spec.ClientMcpTransport; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.JSONRPCMessage; -import org.springframework.ai.mcp.util.Assert; - /** * Implementation of the MCP Stdio transport that communicates with a server process using * standard input/output streams. Messages are exchanged as newline-delimited JSON-RPC @@ -51,7 +37,7 @@ * @author Christian Tzolov * @author Dariusz Jędrzejczyk */ -public class StdioClientTransport implements ClientMcpTransport { +public class StdioClientTransport implements McpClientTransport { private static final Logger logger = LoggerFactory.getLogger(StdioClientTransport.class); @@ -81,7 +67,7 @@ public class StdioClientTransport implements ClientMcpTransport { private volatile boolean isClosing = false; // visible for tests - private Consumer errorHandler = error -> logger.error("Error received: {}", error); + private Consumer stdErrorHandler = error -> logger.info("STDERR Message received: {}", error); /** * Creates a new StdioClientTransport with the specified parameters and default @@ -177,8 +163,8 @@ protected ProcessBuilder getProcessBuilder() { *

    * @param errorHandler a consumer that processes error messages */ - public void setErrorHandler(Consumer errorHandler) { - this.errorHandler = errorHandler; + public void setStdErrorHandler(Consumer errorHandler) { + this.stdErrorHandler = errorHandler; } /** @@ -205,7 +191,6 @@ private void startErrorProcessing() { String line; while (!isClosing && (line = processErrorReader.readLine()) != null) { try { - logger.error("Received error line: {}", line); if (!this.errorSink.tryEmitNext(line).isSuccess()) { if (!isClosing) { logger.error("Failed to emit error message"); @@ -243,7 +228,7 @@ private void handleIncomingMessages(Function, Mono { - this.errorHandler.accept(e); + this.stdErrorHandler.accept(e); }); } @@ -367,14 +352,15 @@ public Mono closeGracefully() { // Give a short time for any pending messages to be processed return Mono.delay(Duration.ofMillis(100)); - })).then(Mono.fromFuture(() -> { - logger.info("Sending TERM to process"); + })).then(Mono.defer(() -> { + logger.debug("Sending TERM to process"); if (this.process != null) { this.process.destroy(); - return process.onExit(); + return Mono.fromFuture(process.onExit()); } else { - return CompletableFuture.failedFuture(new RuntimeException("Process not started")); + logger.warn("Process not started"); + return Mono.empty(); } })).doOnNext(process -> { if (process.exitValue() != 0) { @@ -388,7 +374,7 @@ public Mono closeGracefully() { errorScheduler.dispose(); outboundScheduler.dispose(); - logger.info("Graceful shutdown completed"); + logger.debug("Graceful shutdown completed"); } catch (Exception e) { logger.error("Error during graceful shutdown", e); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java new file mode 100644 index 000000000..ef69539ad --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -0,0 +1,1510 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.HashMap; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.function.BiFunction; +import java.util.function.Function; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * The Model Context Protocol (MCP) server implementation that provides asynchronous + * communication using Project Reactor's Mono and Flux types. + * + *

    + * This server implements the MCP specification, enabling AI models to expose tools, + * resources, and prompts through a standardized interface. Key features include: + *

      + *
    • Asynchronous communication using reactive programming patterns + *
    • Dynamic tool registration and management + *
    • Resource handling with URI-based addressing + *
    • Prompt template management + *
    • Real-time client notifications for state changes + *
    • Structured logging with configurable severity levels + *
    • Support for client-side AI model sampling + *
    + * + *

    + * The server follows a lifecycle: + *

      + *
    1. Initialization - Accepts client connections and negotiates capabilities + *
    2. Normal Operation - Handles client requests and sends notifications + *
    3. Graceful Shutdown - Ensures clean connection termination + *
    + * + *

    + * This implementation uses Project Reactor for non-blocking operations, making it + * suitable for high-throughput scenarios and reactive applications. All operations return + * Mono or Flux types that can be composed into reactive pipelines. + * + *

    + * The server supports runtime modification of its capabilities through methods like + * {@link #addTool}, {@link #addResource}, and {@link #addPrompt}, automatically notifying + * connected clients of changes when configured to do so. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + * @see McpServer + * @see McpSchema + * @see McpClientSession + */ +public class McpAsyncServer { + + private static final Logger logger = LoggerFactory.getLogger(McpAsyncServer.class); + + private final McpAsyncServer delegate; + + McpAsyncServer() { + this.delegate = null; + } + + /** + * Create a new McpAsyncServer with the given transport and capabilities. + * @param mcpTransport The transport layer implementation for MCP communication. + * @param features The MCP server supported features. + * @deprecated This constructor will beremoved in 0.9.0. Use + * {@link #McpAsyncServer(McpServerTransportProvider, ObjectMapper, McpServerFeatures.Async)} + * instead. + */ + @Deprecated + McpAsyncServer(ServerMcpTransport mcpTransport, McpServerFeatures.Async features) { + this.delegate = new LegacyAsyncServer(mcpTransport, features); + } + + /** + * Create a new McpAsyncServer with the given transport provider and capabilities. + * @param mcpTransportProvider The transport layer implementation for MCP + * communication. + * @param features The MCP server supported features. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + */ + McpAsyncServer(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, + McpServerFeatures.Async features) { + this.delegate = new AsyncServerImpl(mcpTransportProvider, objectMapper, features); + } + + /** + * Get the server capabilities that define the supported features and functionality. + * @return The server capabilities + */ + public McpSchema.ServerCapabilities getServerCapabilities() { + return this.delegate.getServerCapabilities(); + } + + /** + * Get the server implementation information. + * @return The server implementation details + */ + public McpSchema.Implementation getServerInfo() { + return this.delegate.getServerInfo(); + } + + /** + * Get the client capabilities that define the supported features and functionality. + * @return The client capabilities + * @deprecated This will be removed in 0.9.0. Use + * {@link McpAsyncServerExchange#getClientCapabilities()}. + */ + @Deprecated + public ClientCapabilities getClientCapabilities() { + return this.delegate.getClientCapabilities(); + } + + /** + * Get the client implementation information. + * @return The client implementation details + * @deprecated This will be removed in 0.9.0. Use + * {@link McpAsyncServerExchange#getClientInfo()}. + */ + @Deprecated + public McpSchema.Implementation getClientInfo() { + return this.delegate.getClientInfo(); + } + + /** + * Gracefully closes the server, allowing any in-progress operations to complete. + * @return A Mono that completes when the server has been closed + */ + public Mono closeGracefully() { + return this.delegate.closeGracefully(); + } + + /** + * Close the server immediately. + */ + public void close() { + this.delegate.close(); + } + + /** + * Retrieves the list of all roots provided by the client. + * @return A Mono that emits the list of roots result. + * @deprecated This will be removed in 0.9.0. Use + * {@link McpAsyncServerExchange#listRoots()}. + */ + @Deprecated + public Mono listRoots() { + return this.delegate.listRoots(null); + } + + /** + * Retrieves a paginated list of roots provided by the server. + * @param cursor Optional pagination cursor from a previous list request + * @return A Mono that emits the list of roots result containing + * @deprecated This will be removed in 0.9.0. Use + * {@link McpAsyncServerExchange#listRoots(String)}. + */ + @Deprecated + public Mono listRoots(String cursor) { + return this.delegate.listRoots(cursor); + } + + // --------------------------------------- + // Tool Management + // --------------------------------------- + + /** + * Add a new tool registration at runtime. + * @param toolRegistration The tool registration to add + * @return Mono that completes when clients have been notified of the change + * @deprecated This method will be removed in 0.9.0. Use + * {@link #addTool(McpServerFeatures.AsyncToolSpecification)}. + */ + @Deprecated + public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistration) { + return this.delegate.addTool(toolRegistration); + } + + /** + * Add a new tool specification at runtime. + * @param toolSpecification The tool specification to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) { + return this.delegate.addTool(toolSpecification); + } + + /** + * Remove a tool handler at runtime. + * @param toolName The name of the tool handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeTool(String toolName) { + return this.delegate.removeTool(toolName); + } + + /** + * Notifies clients that the list of available tools has changed. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyToolsListChanged() { + return this.delegate.notifyToolsListChanged(); + } + + // --------------------------------------- + // Resource Management + // --------------------------------------- + + /** + * Add a new resource handler at runtime. + * @param resourceHandler The resource handler to add + * @return Mono that completes when clients have been notified of the change + * @deprecated This method will be removed in 0.9.0. Use + * {@link #addResource(McpServerFeatures.AsyncResourceSpecification)}. + */ + @Deprecated + public Mono addResource(McpServerFeatures.AsyncResourceRegistration resourceHandler) { + return this.delegate.addResource(resourceHandler); + } + + /** + * Add a new resource handler at runtime. + * @param resourceHandler The resource handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceHandler) { + return this.delegate.addResource(resourceHandler); + } + + /** + * Remove a resource handler at runtime. + * @param resourceUri The URI of the resource handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeResource(String resourceUri) { + return this.delegate.removeResource(resourceUri); + } + + /** + * Notifies clients that the list of available resources has changed. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyResourcesListChanged() { + return this.delegate.notifyResourcesListChanged(); + } + + // --------------------------------------- + // Prompt Management + // --------------------------------------- + + /** + * Add a new prompt handler at runtime. + * @param promptRegistration The prompt handler to add + * @return Mono that completes when clients have been notified of the change + * @deprecated This method will be removed in 0.9.0. Use + * {@link #addPrompt(McpServerFeatures.AsyncPromptSpecification)}. + */ + @Deprecated + public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegistration) { + return this.delegate.addPrompt(promptRegistration); + } + + /** + * Add a new prompt handler at runtime. + * @param promptSpecification The prompt handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) { + return this.delegate.addPrompt(promptSpecification); + } + + /** + * Remove a prompt handler at runtime. + * @param promptName The name of the prompt handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removePrompt(String promptName) { + return this.delegate.removePrompt(promptName); + } + + /** + * Notifies clients that the list of available prompts has changed. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyPromptsListChanged() { + return this.delegate.notifyPromptsListChanged(); + } + + // --------------------------------------- + // Logging Management + // --------------------------------------- + + /** + * Send a logging message notification to all connected clients. Messages below the + * current minimum logging level will be filtered out. + * @param loggingMessageNotification The logging message to send + * @return A Mono that completes when the notification has been sent + */ + public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { + return this.delegate.loggingNotification(loggingMessageNotification); + } + + // --------------------------------------- + // Sampling + // --------------------------------------- + + /** + * Create a new message using the sampling capabilities of the client. The Model + * Context Protocol (MCP) provides a standardized way for servers to request LLM + * sampling (“completions” or “generations”) from language models via clients. This + * flow allows clients to maintain control over model access, selection, and + * permissions while enabling servers to leverage AI capabilities—with no server API + * keys necessary. Servers can request text or image-based interactions and optionally + * include context from MCP servers in their prompts. + * @param createMessageRequest The request to create a new message + * @return A Mono that completes when the message has been created + * @throws McpError if the client has not been initialized or does not support + * sampling capabilities + * @throws McpError if the client does not support the createMessage method + * @see McpSchema.CreateMessageRequest + * @see McpSchema.CreateMessageResult + * @see Sampling + * Specification + * @deprecated This will be removed in 0.9.0. Use + * {@link McpAsyncServerExchange#createMessage(McpSchema.CreateMessageRequest)}. + */ + @Deprecated + public Mono createMessage(McpSchema.CreateMessageRequest createMessageRequest) { + return this.delegate.createMessage(createMessageRequest); + } + + /** + * This method is package-private and used for test only. Should not be called by user + * code. + * @param protocolVersions the Client supported protocol versions. + */ + void setProtocolVersions(List protocolVersions) { + this.delegate.setProtocolVersions(protocolVersions); + } + + private static class AsyncServerImpl extends McpAsyncServer { + + private final McpServerTransportProvider mcpTransportProvider; + + private final ObjectMapper objectMapper; + + private final McpSchema.ServerCapabilities serverCapabilities; + + private final McpSchema.Implementation serverInfo; + + private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); + + private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); + + private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); + + private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; + + private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); + + AsyncServerImpl(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, + McpServerFeatures.Async features) { + this.mcpTransportProvider = mcpTransportProvider; + this.objectMapper = objectMapper; + this.serverInfo = features.serverInfo(); + this.serverCapabilities = features.serverCapabilities(); + this.tools.addAll(features.tools()); + this.resources.putAll(features.resources()); + this.resourceTemplates.addAll(features.resourceTemplates()); + this.prompts.putAll(features.prompts()); + + Map> requestHandlers = new HashMap<>(); + + // Initialize request handlers for standard MCP methods + + // Ping MUST respond with an empty data, but not NULL response. + requestHandlers.put(McpSchema.METHOD_PING, (exchange, params) -> Mono.just(Map.of())); + + // Add tools API handlers if the tool capability is enabled + if (this.serverCapabilities.tools() != null) { + requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); + } + + // Add resources API handlers if provided + if (this.serverCapabilities.resources() != null) { + requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); + } + + // Add prompts API handlers if provider exists + if (this.serverCapabilities.prompts() != null) { + requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); + } + + // Add logging API handlers if the logging capability is enabled + if (this.serverCapabilities.logging() != null) { + requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); + } + + Map notificationHandlers = new HashMap<>(); + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); + + List, Mono>> rootsChangeConsumers = features + .rootsChangeConsumers(); + + if (Utils.isEmpty(rootsChangeConsumers)) { + rootsChangeConsumers = List.of((exchange, + roots) -> Mono.fromRunnable(() -> logger.warn( + "Roots list changed notification, but no consumers provided. Roots list changed: {}", + roots))); + } + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, + asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); + + mcpTransportProvider + .setSessionFactory(transport -> new McpServerSession(UUID.randomUUID().toString(), transport, + this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); + } + + // --------------------------------------- + // Lifecycle Management + // --------------------------------------- + private Mono asyncInitializeRequestHandler( + McpSchema.InitializeRequest initializeRequest) { + return Mono.defer(() -> { + logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", + initializeRequest.protocolVersion(), initializeRequest.capabilities(), + initializeRequest.clientInfo()); + + // The server MUST respond with the highest protocol version it supports + // if + // it does not support the requested (e.g. Client) version. + String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); + + if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { + // If the server supports the requested protocol version, it MUST + // respond + // with the same version. + serverProtocolVersion = initializeRequest.protocolVersion(); + } + else { + logger.warn( + "Client requested unsupported protocol version: {}, so the server will sugggest the {} version instead", + initializeRequest.protocolVersion(), serverProtocolVersion); + } + + return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, + this.serverInfo, null)); + }); + } + + public McpSchema.ServerCapabilities getServerCapabilities() { + return this.serverCapabilities; + } + + public McpSchema.Implementation getServerInfo() { + return this.serverInfo; + } + + @Override + @Deprecated + public ClientCapabilities getClientCapabilities() { + throw new IllegalStateException("This method is deprecated and should not be called"); + } + + @Override + @Deprecated + public McpSchema.Implementation getClientInfo() { + throw new IllegalStateException("This method is deprecated and should not be called"); + } + + @Override + public Mono closeGracefully() { + return this.mcpTransportProvider.closeGracefully(); + } + + @Override + public void close() { + this.mcpTransportProvider.close(); + } + + @Override + @Deprecated + public Mono listRoots() { + return this.listRoots(null); + } + + @Override + @Deprecated + public Mono listRoots(String cursor) { + return Mono.error(new RuntimeException("Not implemented")); + } + + private McpServerSession.NotificationHandler asyncRootsListChangedNotificationHandler( + List, Mono>> rootsChangeConsumers) { + return (exchange, params) -> exchange.listRoots() + .flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) + .flatMap(consumer -> consumer.apply(exchange, listRootsResult.roots())) + .onErrorResume(error -> { + logger.error("Error handling roots list change notification", error); + return Mono.empty(); + }) + .then()); + } + + // --------------------------------------- + // Tool Management + // --------------------------------------- + + @Override + public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) { + if (toolSpecification == null) { + return Mono.error(new McpError("Tool specification must not be null")); + } + if (toolSpecification.tool() == null) { + return Mono.error(new McpError("Tool must not be null")); + } + if (toolSpecification.call() == null) { + return Mono.error(new McpError("Tool call handler must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + // Check for duplicate tool names + if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolSpecification.tool().name()))) { + return Mono + .error(new McpError("Tool with name '" + toolSpecification.tool().name() + "' already exists")); + } + + this.tools.add(toolSpecification); + logger.debug("Added tool handler: {}", toolSpecification.tool().name()); + + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + return Mono.empty(); + }); + } + + @Override + public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistration) { + return this.addTool(toolRegistration.toSpecification()); + } + + @Override + public Mono removeTool(String toolName) { + if (toolName == null) { + return Mono.error(new McpError("Tool name must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + boolean removed = this.tools + .removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName)); + if (removed) { + logger.debug("Removed tool handler: {}", toolName); + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); + }); + } + + @Override + public Mono notifyToolsListChanged() { + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); + } + + private McpServerSession.RequestHandler toolsListRequestHandler() { + return (exchange, params) -> { + List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); + + return Mono.just(new McpSchema.ListToolsResult(tools, null)); + }; + } + + private McpServerSession.RequestHandler toolsCallRequestHandler() { + return (exchange, params) -> { + McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + + Optional toolSpecification = this.tools.stream() + .filter(tr -> callToolRequest.name().equals(tr.tool().name())) + .findAny(); + + if (toolSpecification.isEmpty()) { + return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); + } + + return toolSpecification.map(tool -> tool.call().apply(exchange, callToolRequest.arguments())) + .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); + }; + } + + // --------------------------------------- + // Resource Management + // --------------------------------------- + + @Override + public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceSpecification) { + if (resourceSpecification == null || resourceSpecification.resource() == null) { + return Mono.error(new McpError("Resource must not be null")); + } + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + if (this.resources.putIfAbsent(resourceSpecification.resource().uri(), resourceSpecification) != null) { + return Mono.error(new McpError( + "Resource with URI '" + resourceSpecification.resource().uri() + "' already exists")); + } + logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + }); + } + + @Override + public Mono addResource(McpServerFeatures.AsyncResourceRegistration resourceHandler) { + return this.addResource(resourceHandler.toSpecification()); + } + + @Override + public Mono removeResource(String resourceUri) { + if (resourceUri == null) { + return Mono.error(new McpError("Resource URI must not be null")); + } + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri); + if (removed != null) { + logger.debug("Removed resource handler: {}", resourceUri); + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); + }); + } + + @Override + public Mono notifyResourcesListChanged() { + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); + } + + private McpServerSession.RequestHandler resourcesListRequestHandler() { + return (exchange, params) -> { + var resourceList = this.resources.values() + .stream() + .map(McpServerFeatures.AsyncResourceSpecification::resource) + .toList(); + return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); + }; + } + + private McpServerSession.RequestHandler resourceTemplateListRequestHandler() { + return (exchange, params) -> Mono + .just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null)); + + } + + private McpServerSession.RequestHandler resourcesReadRequestHandler() { + return (exchange, params) -> { + McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + var resourceUri = resourceRequest.uri(); + McpServerFeatures.AsyncResourceSpecification specification = this.resources.get(resourceUri); + if (specification != null) { + return specification.readHandler().apply(exchange, resourceRequest); + } + return Mono.error(new McpError("Resource not found: " + resourceUri)); + }; + } + + // --------------------------------------- + // Prompt Management + // --------------------------------------- + + @Override + public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) { + if (promptSpecification == null) { + return Mono.error(new McpError("Prompt specification must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new McpError("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncPromptSpecification specification = this.prompts + .putIfAbsent(promptSpecification.prompt().name(), promptSpecification); + if (specification != null) { + return Mono.error(new McpError( + "Prompt with name '" + promptSpecification.prompt().name() + "' already exists")); + } + + logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); + + // Servers that declared the listChanged capability SHOULD send a + // notification, + // when the list of available prompts changes + if (this.serverCapabilities.prompts().listChanged()) { + return notifyPromptsListChanged(); + } + return Mono.empty(); + }); + } + + @Override + public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegistration) { + return this.addPrompt(promptRegistration.toSpecification()); + } + + @Override + public Mono removePrompt(String promptName) { + if (promptName == null) { + return Mono.error(new McpError("Prompt name must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new McpError("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncPromptSpecification removed = this.prompts.remove(promptName); + + if (removed != null) { + logger.debug("Removed prompt handler: {}", promptName); + // Servers that declared the listChanged capability SHOULD send a + // notification, when the list of available prompts changes + if (this.serverCapabilities.prompts().listChanged()) { + return this.notifyPromptsListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); + }); + } + + @Override + public Mono notifyPromptsListChanged() { + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); + } + + private McpServerSession.RequestHandler promptsListRequestHandler() { + return (exchange, params) -> { + // TODO: Implement pagination + // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, + // new TypeReference() { + // }); + + var promptList = this.prompts.values() + .stream() + .map(McpServerFeatures.AsyncPromptSpecification::prompt) + .toList(); + + return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); + }; + } + + private McpServerSession.RequestHandler promptsGetRequestHandler() { + return (exchange, params) -> { + McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + + // Implement prompt retrieval logic here + McpServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name()); + if (specification == null) { + return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); + } + + return specification.promptHandler().apply(exchange, promptRequest); + }; + } + + // --------------------------------------- + // Logging Management + // --------------------------------------- + + @Override + public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { + + if (loggingMessageNotification == null) { + return Mono.error(new McpError("Logging message must not be null")); + } + + Map params = this.objectMapper.convertValue(loggingMessageNotification, + new TypeReference>() { + }); + + if (loggingMessageNotification.level().level() < minLoggingLevel.level()) { + return Mono.empty(); + } + + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_MESSAGE, params); + } + + private McpServerSession.RequestHandler setLoggerRequestHandler() { + return (exchange, params) -> { + this.minLoggingLevel = objectMapper.convertValue(params, new TypeReference() { + }); + + return Mono.empty(); + }; + } + + // --------------------------------------- + // Sampling + // --------------------------------------- + + @Override + @Deprecated + public Mono createMessage(McpSchema.CreateMessageRequest createMessageRequest) { + return Mono.error(new RuntimeException("Not implemented")); + } + + @Override + void setProtocolVersions(List protocolVersions) { + this.protocolVersions = protocolVersions; + } + + } + + private static final class LegacyAsyncServer extends McpAsyncServer { + + /** + * The MCP session implementation that manages bidirectional JSON-RPC + * communication between clients and servers. + */ + private final McpClientSession mcpSession; + + private final ServerMcpTransport transport; + + private final McpSchema.ServerCapabilities serverCapabilities; + + private final McpSchema.Implementation serverInfo; + + private McpSchema.ClientCapabilities clientCapabilities; + + private McpSchema.Implementation clientInfo; + + /** + * Thread-safe list of tool handlers that can be modified at runtime. + */ + private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); + + private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); + + private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); + + private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; + + /** + * Supported protocol versions. + */ + private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); + + /** + * Create a new McpAsyncServer with the given transport and capabilities. + * @param mcpTransport The transport layer implementation for MCP communication. + * @param features The MCP server supported features. + */ + LegacyAsyncServer(ServerMcpTransport mcpTransport, McpServerFeatures.Async features) { + + this.serverInfo = features.serverInfo(); + this.serverCapabilities = features.serverCapabilities(); + this.tools.addAll(features.tools()); + this.resources.putAll(features.resources()); + this.resourceTemplates.addAll(features.resourceTemplates()); + this.prompts.putAll(features.prompts()); + + Map> requestHandlers = new HashMap<>(); + + // Initialize request handlers for standard MCP methods + requestHandlers.put(McpSchema.METHOD_INITIALIZE, asyncInitializeRequestHandler()); + + // Ping MUST respond with an empty data, but not NULL response. + requestHandlers.put(McpSchema.METHOD_PING, (params) -> Mono.just(Map.of())); + + // Add tools API handlers if the tool capability is enabled + if (this.serverCapabilities.tools() != null) { + requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); + } + + // Add resources API handlers if provided + if (this.serverCapabilities.resources() != null) { + requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); + } + + // Add prompts API handlers if provider exists + if (this.serverCapabilities.prompts() != null) { + requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); + } + + // Add logging API handlers if the logging capability is enabled + if (this.serverCapabilities.logging() != null) { + requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); + } + + Map notificationHandlers = new HashMap<>(); + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (params) -> Mono.empty()); + + List, Mono>> rootsChangeHandlers = features + .rootsChangeConsumers(); + + List, Mono>> rootsChangeConsumers = rootsChangeHandlers.stream() + .map(handler -> (Function, Mono>) (roots) -> handler.apply(null, roots)) + .toList(); + + if (Utils.isEmpty(rootsChangeConsumers)) { + rootsChangeConsumers = List.of((roots) -> Mono.fromRunnable(() -> logger.warn( + "Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); + } + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, + asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); + + this.transport = mcpTransport; + this.mcpSession = new McpClientSession(Duration.ofSeconds(10), mcpTransport, requestHandlers, + notificationHandlers); + } + + @Override + public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) { + throw new IllegalArgumentException( + "McpAsyncServer configured with legacy " + "transport. Use McpServerTransportProvider instead."); + } + + @Override + public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceHandler) { + throw new IllegalArgumentException( + "McpAsyncServer configured with legacy " + "transport. Use McpServerTransportProvider instead."); + } + + @Override + public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) { + throw new IllegalArgumentException( + "McpAsyncServer configured with legacy " + "transport. Use McpServerTransportProvider instead."); + } + + // --------------------------------------- + // Lifecycle Management + // --------------------------------------- + private McpClientSession.RequestHandler asyncInitializeRequestHandler() { + return params -> { + McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(params, + new TypeReference() { + }); + this.clientCapabilities = initializeRequest.capabilities(); + this.clientInfo = initializeRequest.clientInfo(); + logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", + initializeRequest.protocolVersion(), initializeRequest.capabilities(), + initializeRequest.clientInfo()); + + // The server MUST respond with the highest protocol version it supports + // if + // it does not support the requested (e.g. Client) version. + String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); + + if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { + // If the server supports the requested protocol version, it MUST + // respond + // with the same version. + serverProtocolVersion = initializeRequest.protocolVersion(); + } + else { + logger.warn( + "Client requested unsupported protocol version: {}, so the server will sugggest the {} version instead", + initializeRequest.protocolVersion(), serverProtocolVersion); + } + + return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, + this.serverInfo, null)); + }; + } + + /** + * Get the server capabilities that define the supported features and + * functionality. + * @return The server capabilities + */ + public McpSchema.ServerCapabilities getServerCapabilities() { + return this.serverCapabilities; + } + + /** + * Get the server implementation information. + * @return The server implementation details + */ + public McpSchema.Implementation getServerInfo() { + return this.serverInfo; + } + + /** + * Get the client capabilities that define the supported features and + * functionality. + * @return The client capabilities + */ + public ClientCapabilities getClientCapabilities() { + return this.clientCapabilities; + } + + /** + * Get the client implementation information. + * @return The client implementation details + */ + public McpSchema.Implementation getClientInfo() { + return this.clientInfo; + } + + /** + * Gracefully closes the server, allowing any in-progress operations to complete. + * @return A Mono that completes when the server has been closed + */ + public Mono closeGracefully() { + return this.mcpSession.closeGracefully(); + } + + /** + * Close the server immediately. + */ + public void close() { + this.mcpSession.close(); + } + + private static final TypeReference LIST_ROOTS_RESULT_TYPE_REF = new TypeReference<>() { + }; + + /** + * Retrieves the list of all roots provided by the client. + * @return A Mono that emits the list of roots result. + */ + public Mono listRoots() { + return this.listRoots(null); + } + + /** + * Retrieves a paginated list of roots provided by the server. + * @param cursor Optional pagination cursor from a previous list request + * @return A Mono that emits the list of roots result containing + */ + public Mono listRoots(String cursor) { + return this.mcpSession.sendRequest(McpSchema.METHOD_ROOTS_LIST, new McpSchema.PaginatedRequest(cursor), + LIST_ROOTS_RESULT_TYPE_REF); + } + + private McpClientSession.NotificationHandler asyncRootsListChangedNotificationHandler( + List, Mono>> rootsChangeConsumers) { + return params -> listRoots().flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) + .flatMap(consumer -> consumer.apply(listRootsResult.roots())) + .onErrorResume(error -> { + logger.error("Error handling roots list change notification", error); + return Mono.empty(); + }) + .then()); + } + + // --------------------------------------- + // Tool Management + // --------------------------------------- + + /** + * Add a new tool registration at runtime. + * @param toolRegistration The tool registration to add + * @return Mono that completes when clients have been notified of the change + */ + @Override + public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistration) { + if (toolRegistration == null) { + return Mono.error(new McpError("Tool registration must not be null")); + } + if (toolRegistration.tool() == null) { + return Mono.error(new McpError("Tool must not be null")); + } + if (toolRegistration.call() == null) { + return Mono.error(new McpError("Tool call handler must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + // Check for duplicate tool names + if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolRegistration.tool().name()))) { + return Mono + .error(new McpError("Tool with name '" + toolRegistration.tool().name() + "' already exists")); + } + + this.tools.add(toolRegistration.toSpecification()); + logger.debug("Added tool handler: {}", toolRegistration.tool().name()); + + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + return Mono.empty(); + }); + } + + /** + * Remove a tool handler at runtime. + * @param toolName The name of the tool handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeTool(String toolName) { + if (toolName == null) { + return Mono.error(new McpError("Tool name must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + boolean removed = this.tools + .removeIf(toolRegistration -> toolRegistration.tool().name().equals(toolName)); + if (removed) { + logger.debug("Removed tool handler: {}", toolName); + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); + }); + } + + /** + * Notifies clients that the list of available tools has changed. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyToolsListChanged() { + return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); + } + + private McpClientSession.RequestHandler toolsListRequestHandler() { + return params -> { + List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); + + return Mono.just(new McpSchema.ListToolsResult(tools, null)); + }; + } + + private McpClientSession.RequestHandler toolsCallRequestHandler() { + return params -> { + McpSchema.CallToolRequest callToolRequest = transport.unmarshalFrom(params, + new TypeReference() { + }); + + Optional toolRegistration = this.tools.stream() + .filter(tr -> callToolRequest.name().equals(tr.tool().name())) + .findAny(); + + if (toolRegistration.isEmpty()) { + return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); + } + + return toolRegistration.map(tool -> tool.call().apply(null, callToolRequest.arguments())) + .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); + }; + } + + // --------------------------------------- + // Resource Management + // --------------------------------------- + + /** + * Add a new resource handler at runtime. + * @param resourceHandler The resource handler to add + * @return Mono that completes when clients have been notified of the change + */ + @Override + public Mono addResource(McpServerFeatures.AsyncResourceRegistration resourceHandler) { + if (resourceHandler == null || resourceHandler.resource() == null) { + return Mono.error(new McpError("Resource must not be null")); + } + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + if (this.resources.putIfAbsent(resourceHandler.resource().uri(), + resourceHandler.toSpecification()) != null) { + return Mono.error(new McpError( + "Resource with URI '" + resourceHandler.resource().uri() + "' already exists")); + } + logger.debug("Added resource handler: {}", resourceHandler.resource().uri()); + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + }); + } + + /** + * Remove a resource handler at runtime. + * @param resourceUri The URI of the resource handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeResource(String resourceUri) { + if (resourceUri == null) { + return Mono.error(new McpError("Resource URI must not be null")); + } + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri); + if (removed != null) { + logger.debug("Removed resource handler: {}", resourceUri); + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); + }); + } + + /** + * Notifies clients that the list of available resources has changed. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyResourcesListChanged() { + return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); + } + + private McpClientSession.RequestHandler resourcesListRequestHandler() { + return params -> { + var resourceList = this.resources.values() + .stream() + .map(McpServerFeatures.AsyncResourceSpecification::resource) + .toList(); + return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); + }; + } + + private McpClientSession.RequestHandler resourceTemplateListRequestHandler() { + return params -> Mono.just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null)); + + } + + private McpClientSession.RequestHandler resourcesReadRequestHandler() { + return params -> { + McpSchema.ReadResourceRequest resourceRequest = transport.unmarshalFrom(params, + new TypeReference() { + }); + var resourceUri = resourceRequest.uri(); + McpServerFeatures.AsyncResourceSpecification registration = this.resources.get(resourceUri); + if (registration != null) { + return registration.readHandler().apply(null, resourceRequest); + } + return Mono.error(new McpError("Resource not found: " + resourceUri)); + }; + } + + // --------------------------------------- + // Prompt Management + // --------------------------------------- + + /** + * Add a new prompt handler at runtime. + * @param promptRegistration The prompt handler to add + * @return Mono that completes when clients have been notified of the change + */ + @Override + public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegistration) { + if (promptRegistration == null) { + return Mono.error(new McpError("Prompt registration must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new McpError("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncPromptSpecification registration = this.prompts + .putIfAbsent(promptRegistration.prompt().name(), promptRegistration.toSpecification()); + if (registration != null) { + return Mono.error(new McpError( + "Prompt with name '" + promptRegistration.prompt().name() + "' already exists")); + } + + logger.debug("Added prompt handler: {}", promptRegistration.prompt().name()); + + // Servers that declared the listChanged capability SHOULD send a + // notification, + // when the list of available prompts changes + if (this.serverCapabilities.prompts().listChanged()) { + return notifyPromptsListChanged(); + } + return Mono.empty(); + }); + } + + /** + * Remove a prompt handler at runtime. + * @param promptName The name of the prompt handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removePrompt(String promptName) { + if (promptName == null) { + return Mono.error(new McpError("Prompt name must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new McpError("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncPromptSpecification removed = this.prompts.remove(promptName); + + if (removed != null) { + logger.debug("Removed prompt handler: {}", promptName); + // Servers that declared the listChanged capability SHOULD send a + // notification, when the list of available prompts changes + if (this.serverCapabilities.prompts().listChanged()) { + return this.notifyPromptsListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); + }); + } + + /** + * Notifies clients that the list of available prompts has changed. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyPromptsListChanged() { + return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); + } + + private McpClientSession.RequestHandler promptsListRequestHandler() { + return params -> { + // TODO: Implement pagination + // McpSchema.PaginatedRequest request = transport.unmarshalFrom(params, + // new TypeReference() { + // }); + + var promptList = this.prompts.values() + .stream() + .map(McpServerFeatures.AsyncPromptSpecification::prompt) + .toList(); + + return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); + }; + } + + private McpClientSession.RequestHandler promptsGetRequestHandler() { + return params -> { + McpSchema.GetPromptRequest promptRequest = transport.unmarshalFrom(params, + new TypeReference() { + }); + + // Implement prompt retrieval logic here + McpServerFeatures.AsyncPromptSpecification registration = this.prompts.get(promptRequest.name()); + if (registration == null) { + return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); + } + + return registration.promptHandler().apply(null, promptRequest); + }; + } + + // --------------------------------------- + // Logging Management + // --------------------------------------- + + /** + * Send a logging message notification to all connected clients. Messages below + * the current minimum logging level will be filtered out. + * @param loggingMessageNotification The logging message to send + * @return A Mono that completes when the notification has been sent + */ + public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { + + if (loggingMessageNotification == null) { + return Mono.error(new McpError("Logging message must not be null")); + } + + Map params = this.transport.unmarshalFrom(loggingMessageNotification, + new TypeReference>() { + }); + + if (loggingMessageNotification.level().level() < minLoggingLevel.level()) { + return Mono.empty(); + } + + return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_MESSAGE, params); + } + + /** + * Handles requests to set the minimum logging level. Messages below this level + * will not be sent. + * @return A handler that processes logging level change requests + */ + private McpClientSession.RequestHandler setLoggerRequestHandler() { + return params -> { + this.minLoggingLevel = transport.unmarshalFrom(params, new TypeReference() { + }); + + return Mono.empty(); + }; + } + + // --------------------------------------- + // Sampling + // --------------------------------------- + private static final TypeReference CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() { + }; + + /** + * Create a new message using the sampling capabilities of the client. The Model + * Context Protocol (MCP) provides a standardized way for servers to request LLM + * sampling (“completions” or “generations”) from language models via clients. + * This flow allows clients to maintain control over model access, selection, and + * permissions while enabling servers to leverage AI capabilities—with no server + * API keys necessary. Servers can request text or image-based interactions and + * optionally include context from MCP servers in their prompts. + * @param createMessageRequest The request to create a new message + * @return A Mono that completes when the message has been created + * @throws McpError if the client has not been initialized or does not support + * sampling capabilities + * @throws McpError if the client does not support the createMessage method + * @see McpSchema.CreateMessageRequest + * @see McpSchema.CreateMessageResult + * @see Sampling + * Specification + */ + public Mono createMessage(McpSchema.CreateMessageRequest createMessageRequest) { + + if (this.clientCapabilities == null) { + return Mono.error(new McpError("Client must be initialized. Call the initialize method first!")); + } + if (this.clientCapabilities.sampling() == null) { + return Mono.error(new McpError("Client must be configured with sampling capabilities")); + } + return this.mcpSession.sendRequest(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, createMessageRequest, + CREATE_MESSAGE_RESULT_TYPE_REF); + } + + /** + * This method is package-private and used for test only. Should not be called by + * user code. + * @param protocolVersions the Client supported protocol versions. + */ + void setProtocolVersions(List protocolVersions) { + this.protocolVersions = protocolVersions; + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java new file mode 100644 index 000000000..658628448 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -0,0 +1,104 @@ +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import reactor.core.publisher.Mono; + +/** + * Represents an asynchronous exchange with a Model Context Protocol (MCP) client. The + * exchange provides methods to interact with the client and query its capabilities. + * + * @author Dariusz Jędrzejczyk + */ +public class McpAsyncServerExchange { + + private final McpServerSession session; + + private final McpSchema.ClientCapabilities clientCapabilities; + + private final McpSchema.Implementation clientInfo; + + private static final TypeReference CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() { + }; + + private static final TypeReference LIST_ROOTS_RESULT_TYPE_REF = new TypeReference<>() { + }; + + /** + * Create a new asynchronous exchange with the client. + * @param session The server session representing a 1-1 interaction. + * @param clientCapabilities The client capabilities that define the supported + * features and functionality. + * @param clientInfo The client implementation information. + */ + public McpAsyncServerExchange(McpServerSession session, McpSchema.ClientCapabilities clientCapabilities, + McpSchema.Implementation clientInfo) { + this.session = session; + this.clientCapabilities = clientCapabilities; + this.clientInfo = clientInfo; + } + + /** + * Get the client capabilities that define the supported features and functionality. + * @return The client capabilities + */ + public McpSchema.ClientCapabilities getClientCapabilities() { + return this.clientCapabilities; + } + + /** + * Get the client implementation information. + * @return The client implementation details + */ + public McpSchema.Implementation getClientInfo() { + return this.clientInfo; + } + + /** + * Create a new message using the sampling capabilities of the client. The Model + * Context Protocol (MCP) provides a standardized way for servers to request LLM + * sampling (“completions” or “generations”) from language models via clients. This + * flow allows clients to maintain control over model access, selection, and + * permissions while enabling servers to leverage AI capabilities—with no server API + * keys necessary. Servers can request text or image-based interactions and optionally + * include context from MCP servers in their prompts. + * @param createMessageRequest The request to create a new message + * @return A Mono that completes when the message has been created + * @see McpSchema.CreateMessageRequest + * @see McpSchema.CreateMessageResult + * @see Sampling + * Specification + */ + public Mono createMessage(McpSchema.CreateMessageRequest createMessageRequest) { + if (this.clientCapabilities == null) { + return Mono.error(new McpError("Client must be initialized. Call the initialize method first!")); + } + if (this.clientCapabilities.sampling() == null) { + return Mono.error(new McpError("Client must be configured with sampling capabilities")); + } + return this.session.sendRequest(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, createMessageRequest, + CREATE_MESSAGE_RESULT_TYPE_REF); + } + + /** + * Retrieves the list of all roots provided by the client. + * @return A Mono that emits the list of roots result. + */ + public Mono listRoots() { + return this.listRoots(null); + } + + /** + * Retrieves a paginated list of roots provided by the client. + * @param cursor Optional pagination cursor from a previous list request + * @return A Mono that emits the list of roots result containing + */ + public Mono listRoots(String cursor) { + return this.session.sendRequest(McpSchema.METHOD_ROOTS_LIST, new McpSchema.PaginatedRequest(cursor), + LIST_ROOTS_RESULT_TYPE_REF); + } + +} diff --git a/mcp/src/main/java/org/springframework/ai/mcp/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java similarity index 57% rename from mcp/src/main/java/org/springframework/ai/mcp/server/McpServer.java rename to mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index 615fd515d..d8dfcb018 100644 --- a/mcp/src/main/java/org/springframework/ai/mcp/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -1,38 +1,28 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.server; +package io.modelcontextprotocol.server; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Collectors; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; +import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; - -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.CallToolResult; -import org.springframework.ai.mcp.spec.McpSchema.ResourceTemplate; -import org.springframework.ai.mcp.spec.McpTransport; -import org.springframework.ai.mcp.spec.ServerMcpTransport; -import org.springframework.ai.mcp.util.Assert; /** * Factory class for creating Model Context Protocol (MCP) servers. MCP servers expose @@ -64,45 +54,50 @@ *

    * The class provides factory methods to create either: *

      - *
    • {@link McpAsyncServer} for non-blocking operations with CompletableFuture responses + *
    • {@link McpAsyncServer} for non-blocking operations with reactive responses *
    • {@link McpSyncServer} for blocking operations with direct responses *
    * *

    * Example of creating a basic synchronous server:

    {@code
    - * McpServer.sync(transport)
    + * McpServer.sync(transportProvider)
      *     .serverInfo("my-server", "1.0.0")
      *     .tool(new Tool("calculator", "Performs calculations", schema),
    - *           args -> new CallToolResult("Result: " + calculate(args)))
    + *           (exchange, args) -> new CallToolResult("Result: " + calculate(args)))
      *     .build();
      * }
    * * Example of creating a basic asynchronous server:
    {@code
    - * McpServer.async(transport)
    + * McpServer.async(transportProvider)
      *     .serverInfo("my-server", "1.0.0")
      *     .tool(new Tool("calculator", "Performs calculations", schema),
    - *           args -> Mono.just(new CallToolResult("Result: " + calculate(args))))
    + *           (exchange, args) -> Mono.fromSupplier(() -> calculate(args))
    + *               .map(result -> new CallToolResult("Result: " + result)))
      *     .build();
      * }
    * *

    * Example with comprehensive asynchronous configuration:

    {@code
    - * McpServer.async(transport)
    + * McpServer.async(transportProvider)
      *     .serverInfo("advanced-server", "2.0.0")
      *     .capabilities(new ServerCapabilities(...))
      *     // Register tools
      *     .tools(
    - *         new McpServerFeatures.AsyncToolRegistration(calculatorTool,
    - *             args -> Mono.just(new CallToolResult("Result: " + calculate(args)))),
    - *         new McpServerFeatures.AsyncToolRegistration(weatherTool,
    - *             args -> Mono.just(new CallToolResult("Weather: " + getWeather(args))))
    + *         new McpServerFeatures.AsyncToolSpecification(calculatorTool,
    + *             (exchange, args) -> Mono.fromSupplier(() -> calculate(args))
    + *                 .map(result -> new CallToolResult("Result: " + result))),
    + *         new McpServerFeatures.AsyncToolSpecification(weatherTool,
    + *             (exchange, args) -> Mono.fromSupplier(() -> getWeather(args))
    + *                 .map(result -> new CallToolResult("Weather: " + result)))
      *     )
      *     // Register resources
      *     .resources(
    - *         new McpServerFeatures.AsyncResourceRegistration(fileResource,
    - *             req -> Mono.just(new ReadResourceResult(readFile(req)))),
    - *         new McpServerFeatures.AsyncResourceRegistration(dbResource,
    - *             req -> Mono.just(new ReadResourceResult(queryDb(req))))
    + *         new McpServerFeatures.AsyncResourceSpecification(fileResource,
    + *             (exchange, req) -> Mono.fromSupplier(() -> readFile(req))
    + *                 .map(ReadResourceResult::new)),
    + *         new McpServerFeatures.AsyncResourceSpecification(dbResource,
    + *             (exchange, req) -> Mono.fromSupplier(() -> queryDb(req))
    + *                 .map(ReadResourceResult::new))
      *     )
      *     // Add resource templates
      *     .resourceTemplates(
    @@ -111,10 +106,12 @@
      *     )
      *     // Register prompts
      *     .prompts(
    - *         new McpServerFeatures.AsyncPromptRegistration(analysisPrompt,
    - *             req -> Mono.just(new GetPromptResult(generateAnalysisPrompt(req)))),
    + *         new McpServerFeatures.AsyncPromptSpecification(analysisPrompt,
    + *             (exchange, req) -> Mono.fromSupplier(() -> generateAnalysisPrompt(req))
    + *                 .map(GetPromptResult::new)),
      *         new McpServerFeatures.AsyncPromptRegistration(summaryPrompt,
    - *             req -> Mono.just(new GetPromptResult(generateSummaryPrompt(req))))
    + *             (exchange, req) -> Mono.fromSupplier(() -> generateSummaryPrompt(req))
    + *                 .map(GetPromptResult::new))
      *     )
      *     .build();
      * }
    @@ -123,55 +120,75 @@ * @author Dariusz Jędrzejczyk * @see McpAsyncServer * @see McpSyncServer - * @see McpTransport + * @see McpServerTransportProvider */ public interface McpServer { /** * Starts building a synchronous MCP server that provides blocking operations. - * Synchronous servers process each request to completion before handling the next - * one, making them simpler to implement but potentially less performant for - * concurrent operations. + * Synchronous servers block the current Thread's execution upon each request before + * giving the control back to the caller, making them simpler to implement but + * potentially less scalable for concurrent operations. + * @param transportProvider The transport layer implementation for MCP communication. + * @return A new instance of {@link SyncSpecification} for configuring the server. + */ + static SyncSpecification sync(McpServerTransportProvider transportProvider) { + return new SyncSpecification(transportProvider); + } + + /** + * Starts building a synchronous MCP server that provides blocking operations. + * Synchronous servers block the current Thread's execution upon each request before + * giving the control back to the caller, making them simpler to implement but + * potentially less scalable for concurrent operations. * @param transport The transport layer implementation for MCP communication * @return A new instance of {@link SyncSpec} for configuring the server. + * @deprecated This method will be removed in 0.9.0. Use + * {@link #sync(McpServerTransportProvider)} instead. */ + @Deprecated static SyncSpec sync(ServerMcpTransport transport) { return new SyncSpec(transport); } /** - * Starts building an asynchronous MCP server that provides blocking operations. - * Asynchronous servers can handle multiple requests concurrently using a functional - * paradigm with non-blocking server transports, making them more efficient for - * high-concurrency scenarios but more complex to implement. - * @param transport The transport layer implementation for MCP communication - * @return A new instance of {@link SyncSpec} for configuring the server. + * Starts building an asynchronous MCP server that provides non-blocking operations. + * Asynchronous servers can handle multiple requests concurrently on a single Thread + * using a functional paradigm with non-blocking server transports, making them more + * scalable for high-concurrency scenarios but more complex to implement. + * @param transportProvider The transport layer implementation for MCP communication. + * @return A new instance of {@link AsyncSpecification} for configuring the server. */ - static AsyncSpec async(ServerMcpTransport transport) { - return new AsyncSpec(transport); + static AsyncSpecification async(McpServerTransportProvider transportProvider) { + return new AsyncSpecification(transportProvider); } /** - * Start building an MCP server with the specified transport. + * Starts building an asynchronous MCP server that provides non-blocking operations. + * Asynchronous servers can handle multiple requests concurrently on a single Thread + * using a functional paradigm with non-blocking server transports, making them more + * scalable for high-concurrency scenarios but more complex to implement. * @param transport The transport layer implementation for MCP communication - * @return A new builder instance - * @deprecated Use {@link #sync(ServerMcpTransport)} or - * {@link #async(ServerMcpTransport)} to create a server instance. + * @return A new instance of {@link AsyncSpec} for configuring the server. + * @deprecated This method will be removed in 0.9.0. Use + * {@link #async(McpServerTransportProvider)} instead. */ @Deprecated - public static Builder using(ServerMcpTransport transport) { - return new Builder(transport); + static AsyncSpec async(ServerMcpTransport transport) { + return new AsyncSpec(transport); } /** * Asynchronous server specification. */ - class AsyncSpec { + class AsyncSpecification { private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", "1.0.0"); - private final ServerMcpTransport transport; + private final McpServerTransportProvider transportProvider; + + private ObjectMapper objectMapper; private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; @@ -184,7 +201,7 @@ class AsyncSpec { * Each tool is uniquely identified by a name and includes metadata describing its * schema. */ - private final List tools = new ArrayList<>(); + private final List tools = new ArrayList<>(); /** * The Model Context Protocol (MCP) provides a standardized way for servers to @@ -193,7 +210,7 @@ class AsyncSpec { * application-specific information. Each resource is uniquely identified by a * URI. */ - private final Map resources = new HashMap<>(); + private final Map resources = new HashMap<>(); private final List resourceTemplates = new ArrayList<>(); @@ -204,13 +221,13 @@ class AsyncSpec { * discover available prompts, retrieve their contents, and provide arguments to * customize them. */ - private final Map prompts = new HashMap<>(); + private final Map prompts = new HashMap<>(); - private final List, Mono>> rootsChangeConsumers = new ArrayList<>(); + private final List, Mono>> rootsChangeHandlers = new ArrayList<>(); - private AsyncSpec(ServerMcpTransport transport) { - Assert.notNull(transport, "Transport must not be null"); - this.transport = transport; + private AsyncSpecification(McpServerTransportProvider transportProvider) { + Assert.notNull(transportProvider, "Transport provider must not be null"); + this.transportProvider = transportProvider; } /** @@ -222,7 +239,7 @@ private AsyncSpec(ServerMcpTransport transport) { * @return This builder instance for method chaining * @throws IllegalArgumentException if serverInfo is null */ - public AsyncSpec serverInfo(McpSchema.Implementation serverInfo) { + public AsyncSpecification serverInfo(McpSchema.Implementation serverInfo) { Assert.notNull(serverInfo, "Server info must not be null"); this.serverInfo = serverInfo; return this; @@ -238,7 +255,7 @@ public AsyncSpec serverInfo(McpSchema.Implementation serverInfo) { * @throws IllegalArgumentException if name or version is null or empty * @see #serverInfo(McpSchema.Implementation) */ - public AsyncSpec serverInfo(String name, String version) { + public AsyncSpecification serverInfo(String name, String version) { Assert.hasText(name, "Name must not be null or empty"); Assert.hasText(version, "Version must not be null or empty"); this.serverInfo = new McpSchema.Implementation(name, version); @@ -253,15 +270,14 @@ public AsyncSpec serverInfo(String name, String version) { *
  • Tool execution *
  • Resource access *
  • Prompt handling - *
  • Streaming responses - *
  • Batch operations * * @param serverCapabilities The server capabilities configuration. Must not be * null. * @return This builder instance for method chaining * @throws IllegalArgumentException if serverCapabilities is null */ - public AsyncSpec capabilities(McpSchema.ServerCapabilities serverCapabilities) { + public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + Assert.notNull(serverCapabilities, "Server capabilities must not be null"); this.serverCapabilities = serverCapabilities; return this; } @@ -269,26 +285,31 @@ public AsyncSpec capabilities(McpSchema.ServerCapabilities serverCapabilities) { /** * Adds a single tool with its implementation handler to the server. This is a * convenience method for registering individual tools without creating a - * {@link McpServerFeatures.AsyncToolRegistration} explicitly. + * {@link McpServerFeatures.AsyncToolSpecification} explicitly. * *

    * Example usage:

    {@code
     		 * .tool(
     		 *     new Tool("calculator", "Performs calculations", schema),
    -		 *     args -> Mono.just(new CallToolResult("Result: " + calculate(args)))
    +		 *     (exchange, args) -> Mono.fromSupplier(() -> calculate(args))
    +		 *         .map(result -> new CallToolResult("Result: " + result))
     		 * )
     		 * }
    * @param tool The tool definition including name, description, and schema. Must * not be null. * @param handler The function that implements the tool's logic. Must not be null. + * The function's first argument is an {@link McpAsyncServerExchange} upon which + * the server can interact with the connected client. The second argument is the + * map of arguments passed to the tool. * @return This builder instance for method chaining * @throws IllegalArgumentException if tool or handler is null */ - public AsyncSpec tool(McpSchema.Tool tool, Function, Mono> handler) { + public AsyncSpecification tool(McpSchema.Tool tool, + BiFunction, Mono> handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); - this.tools.add(new McpServerFeatures.AsyncToolRegistration(tool, handler)); + this.tools.add(new McpServerFeatures.AsyncToolSpecification(tool, handler)); return this; } @@ -297,15 +318,15 @@ public AsyncSpec tool(McpSchema.Tool tool, Function, Mono toolRegistrations) { - Assert.notNull(toolRegistrations, "Tool handlers list must not be null"); - this.tools.addAll(toolRegistrations); + public AsyncSpecification tools(List toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + this.tools.addAll(toolSpecifications); return this; } @@ -316,18 +337,19 @@ public AsyncSpec tools(List toolRegistr *

    * Example usage:

    {@code
     		 * .tools(
    -		 *     new McpServerFeatures.AsyncToolRegistration(calculatorTool, calculatorHandler),
    -		 *     new McpServerFeatures.AsyncToolRegistration(weatherTool, weatherHandler),
    -		 *     new McpServerFeatures.AsyncToolRegistration(fileManagerTool, fileManagerHandler)
    +		 *     new McpServerFeatures.AsyncToolSpecification(calculatorTool, calculatorHandler),
    +		 *     new McpServerFeatures.AsyncToolSpecification(weatherTool, weatherHandler),
    +		 *     new McpServerFeatures.AsyncToolSpecification(fileManagerTool, fileManagerHandler)
     		 * )
     		 * }
    - * @param toolRegistrations The tool registrations to add. Must not be null. + * @param toolSpecifications The tool specifications to add. Must not be null. * @return This builder instance for method chaining - * @throws IllegalArgumentException if toolRegistrations is null + * @throws IllegalArgumentException if toolSpecifications is null * @see #tools(List) */ - public AsyncSpec tools(McpServerFeatures.AsyncToolRegistration... toolRegistrations) { - for (McpServerFeatures.AsyncToolRegistration tool : toolRegistrations) { + public AsyncSpecification tools(McpServerFeatures.AsyncToolSpecification... toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + for (McpServerFeatures.AsyncToolSpecification tool : toolSpecifications) { this.tools.add(tool); } return this; @@ -337,29 +359,31 @@ public AsyncSpec tools(McpServerFeatures.AsyncToolRegistration... toolRegistrati * Registers multiple resources with their handlers using a Map. This method is * useful when resources are dynamically generated or loaded from a configuration * source. - * @param resourceRegsitrations Map of resource name to registration. Must not be - * null. + * @param resourceSpecifications Map of resource name to specification. Must not + * be null. * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceRegsitrations is null - * @see #resources(McpServerFeatures.AsyncResourceRegistration...) + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpServerFeatures.AsyncResourceSpecification...) */ - public AsyncSpec resources(Map resourceRegsitrations) { - Assert.notNull(resourceRegsitrations, "Resource handlers map must not be null"); - this.resources.putAll(resourceRegsitrations); + public AsyncSpecification resources( + Map resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); + this.resources.putAll(resourceSpecifications); return this; } /** * Registers multiple resources with their handlers using a List. This method is * useful when resources need to be added in bulk from a collection. - * @param resourceRegsitrations List of resource registrations. Must not be null. + * @param resourceSpecifications List of resource specifications. Must not be + * null. * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceRegsitrations is null - * @see #resources(McpServerFeatures.AsyncResourceRegistration...) + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpServerFeatures.AsyncResourceSpecification...) */ - public AsyncSpec resources(List resourceRegsitrations) { - Assert.notNull(resourceRegsitrations, "Resource handlers list must not be null"); - for (McpServerFeatures.AsyncResourceRegistration resource : resourceRegsitrations) { + public AsyncSpecification resources(List resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (McpServerFeatures.AsyncResourceSpecification resource : resourceSpecifications) { this.resources.put(resource.resource().uri(), resource); } return this; @@ -372,19 +396,19 @@ public AsyncSpec resources(List res *

    * Example usage:

    {@code
     		 * .resources(
    -		 *     new McpServerFeatures.AsyncResourceRegistration(fileResource, fileHandler),
    -		 *     new McpServerFeatures.AsyncResourceRegistration(dbResource, dbHandler),
    -		 *     new McpServerFeatures.AsyncResourceRegistration(apiResource, apiHandler)
    +		 *     new McpServerFeatures.AsyncResourceSpecification(fileResource, fileHandler),
    +		 *     new McpServerFeatures.AsyncResourceSpecification(dbResource, dbHandler),
    +		 *     new McpServerFeatures.AsyncResourceSpecification(apiResource, apiHandler)
     		 * )
     		 * }
    - * @param resourceRegistrations The resource registrations to add. Must not be + * @param resourceSpecifications The resource specifications to add. Must not be * null. * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceRegistrations is null + * @throws IllegalArgumentException if resourceSpecifications is null */ - public AsyncSpec resources(McpServerFeatures.AsyncResourceRegistration... resourceRegistrations) { - Assert.notNull(resourceRegistrations, "Resource handlers list must not be null"); - for (McpServerFeatures.AsyncResourceRegistration resource : resourceRegistrations) { + public AsyncSpecification resources(McpServerFeatures.AsyncResourceSpecification... resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (McpServerFeatures.AsyncResourceSpecification resource : resourceSpecifications) { this.resources.put(resource.resource().uri(), resource); } return this; @@ -404,9 +428,11 @@ public AsyncSpec resources(McpServerFeatures.AsyncResourceRegistration... resour * @param resourceTemplates List of resource templates. If null, clears existing * templates. * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. * @see #resourceTemplates(ResourceTemplate...) */ - public AsyncSpec resourceTemplates(List resourceTemplates) { + public AsyncSpecification resourceTemplates(List resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); this.resourceTemplates.addAll(resourceTemplates); return this; } @@ -416,9 +442,11 @@ public AsyncSpec resourceTemplates(List resourceTemplates) { * alternative to {@link #resourceTemplates(List)}. * @param resourceTemplates The resource templates to set. * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. * @see #resourceTemplates(List) */ - public AsyncSpec resourceTemplates(ResourceTemplate... resourceTemplates) { + public AsyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); for (ResourceTemplate resourceTemplate : resourceTemplates) { this.resourceTemplates.add(resourceTemplate); } @@ -432,16 +460,18 @@ public AsyncSpec resourceTemplates(ResourceTemplate... resourceTemplates) { * *

    * Example usage:

    {@code
    -		 * .prompts(Map.of("analysis", new McpServerFeatures.AsyncPromptRegistration(
    +		 * .prompts(Map.of("analysis", new McpServerFeatures.AsyncPromptSpecification(
     		 *     new Prompt("analysis", "Code analysis template"),
    -		 *     request -> Mono.just(new GetPromptResult(generateAnalysisPrompt(request)))
    +		 *     request -> Mono.fromSupplier(() -> generateAnalysisPrompt(request))
    +		 *         .map(GetPromptResult::new)
     		 * )));
     		 * }
    - * @param prompts Map of prompt name to registration. Must not be null. + * @param prompts Map of prompt name to specification. Must not be null. * @return This builder instance for method chaining * @throws IllegalArgumentException if prompts is null */ - public AsyncSpec prompts(Map prompts) { + public AsyncSpecification prompts(Map prompts) { + Assert.notNull(prompts, "Prompts map must not be null"); this.prompts.putAll(prompts); return this; } @@ -449,13 +479,14 @@ public AsyncSpec prompts(Map /** * Registers multiple prompts with their handlers using a List. This method is * useful when prompts need to be added in bulk from a collection. - * @param prompts List of prompt registrations. Must not be null. + * @param prompts List of prompt specifications. Must not be null. * @return This builder instance for method chaining * @throws IllegalArgumentException if prompts is null - * @see #prompts(McpServerFeatures.AsyncPromptRegistration...) + * @see #prompts(McpServerFeatures.AsyncPromptSpecification...) */ - public AsyncSpec prompts(List prompts) { - for (McpServerFeatures.AsyncPromptRegistration prompt : prompts) { + public AsyncSpecification prompts(List prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (McpServerFeatures.AsyncPromptSpecification prompt : prompts) { this.prompts.put(prompt.prompt().name(), prompt); } return this; @@ -468,17 +499,18 @@ public AsyncSpec prompts(List prompts *

    * Example usage:

    {@code
     		 * .prompts(
    -		 *     new McpServerFeatures.AsyncPromptRegistration(analysisPrompt, analysisHandler),
    -		 *     new McpServerFeatures.AsyncPromptRegistration(summaryPrompt, summaryHandler),
    -		 *     new McpServerFeatures.AsyncPromptRegistration(reviewPrompt, reviewHandler)
    +		 *     new McpServerFeatures.AsyncPromptSpecification(analysisPrompt, analysisHandler),
    +		 *     new McpServerFeatures.AsyncPromptSpecification(summaryPrompt, summaryHandler),
    +		 *     new McpServerFeatures.AsyncPromptSpecification(reviewPrompt, reviewHandler)
     		 * )
     		 * }
    - * @param prompts The prompt registrations to add. Must not be null. + * @param prompts The prompt specifications to add. Must not be null. * @return This builder instance for method chaining * @throws IllegalArgumentException if prompts is null */ - public AsyncSpec prompts(McpServerFeatures.AsyncPromptRegistration... prompts) { - for (McpServerFeatures.AsyncPromptRegistration prompt : prompts) { + public AsyncSpecification prompts(McpServerFeatures.AsyncPromptSpecification... prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (McpServerFeatures.AsyncPromptSpecification prompt : prompts) { this.prompts.put(prompt.prompt().name(), prompt); } return this; @@ -488,13 +520,16 @@ public AsyncSpec prompts(McpServerFeatures.AsyncPromptRegistration... prompts) { * Registers a consumer that will be notified when the list of roots changes. This * is useful for updating resource availability dynamically, such as when new * files are added or removed. - * @param consumer The consumer to register. Must not be null. + * @param handler The handler to register. Must not be null. The function's first + * argument is an {@link McpAsyncServerExchange} upon which the server can + * interact with the connected client. The second argument is the list of roots. * @return This builder instance for method chaining * @throws IllegalArgumentException if consumer is null */ - public AsyncSpec rootsChangeConsumer(Function, Mono> consumer) { - Assert.notNull(consumer, "Consumer must not be null"); - this.rootsChangeConsumers.add(consumer); + public AsyncSpecification rootsChangeHandler( + BiFunction, Mono> handler) { + Assert.notNull(handler, "Consumer must not be null"); + this.rootsChangeHandlers.add(handler); return this; } @@ -502,13 +537,15 @@ public AsyncSpec rootsChangeConsumer(Function, Mono> * Registers multiple consumers that will be notified when the list of roots * changes. This method is useful when multiple consumers need to be registered at * once. - * @param consumers The list of consumers to register. Must not be null. + * @param handlers The list of handlers to register. Must not be null. * @return This builder instance for method chaining * @throws IllegalArgumentException if consumers is null + * @see #rootsChangeHandler(BiFunction) */ - public AsyncSpec rootsChangeConsumers(List, Mono>> consumers) { - Assert.notNull(consumers, "Consumers list must not be null"); - this.rootsChangeConsumers.addAll(consumers); + public AsyncSpecification rootsChangeHandlers( + List, Mono>> handlers) { + Assert.notNull(handlers, "Handlers list must not be null"); + this.rootsChangeHandlers.addAll(handlers); return this; } @@ -516,26 +553,39 @@ public AsyncSpec rootsChangeConsumers(List, Mono, Mono>... consumers) { - for (Function, Mono> consumer : consumers) { - this.rootsChangeConsumers.add(consumer); - } + public AsyncSpecification rootsChangeHandlers( + @SuppressWarnings("unchecked") BiFunction, Mono>... handlers) { + Assert.notNull(handlers, "Handlers list must not be null"); + return this.rootsChangeHandlers(Arrays.asList(handlers)); + } + + /** + * Sets the object mapper to use for serializing and deserializing JSON messages. + * @param objectMapper the instance to use. Must not be null. + * @return This builder instance for method chaining. + * @throws IllegalArgumentException if objectMapper is null + */ + public AsyncSpecification objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; return this; } /** * Builds an asynchronous MCP server that provides non-blocking operations. * @return A new instance of {@link McpAsyncServer} configured with this builder's - * settings + * settings. */ public McpAsyncServer build() { - return new McpAsyncServer(this.transport, - new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, this.resources, - this.resourceTemplates, this.prompts, this.rootsChangeConsumers)); + var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, + this.resources, this.resourceTemplates, this.prompts, this.rootsChangeHandlers); + var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + return new McpAsyncServer(this.transportProvider, mapper, features); } } @@ -543,12 +593,14 @@ public McpAsyncServer build() { /** * Synchronous server specification. */ - class SyncSpec { + class SyncSpecification { private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", "1.0.0"); - private final ServerMcpTransport transport; + private final McpServerTransportProvider transportProvider; + + private ObjectMapper objectMapper; private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; @@ -561,7 +613,7 @@ class SyncSpec { * Each tool is uniquely identified by a name and includes metadata describing its * schema. */ - private final List tools = new ArrayList<>(); + private final List tools = new ArrayList<>(); /** * The Model Context Protocol (MCP) provides a standardized way for servers to @@ -570,7 +622,7 @@ class SyncSpec { * application-specific information. Each resource is uniquely identified by a * URI. */ - private final Map resources = new HashMap<>(); + private final Map resources = new HashMap<>(); private final List resourceTemplates = new ArrayList<>(); @@ -581,13 +633,13 @@ class SyncSpec { * discover available prompts, retrieve their contents, and provide arguments to * customize them. */ - private final Map prompts = new HashMap<>(); + private final Map prompts = new HashMap<>(); - private final List>> rootsChangeConsumers = new ArrayList<>(); + private final List>> rootsChangeHandlers = new ArrayList<>(); - private SyncSpec(ServerMcpTransport transport) { - Assert.notNull(transport, "Transport must not be null"); - this.transport = transport; + private SyncSpecification(McpServerTransportProvider transportProvider) { + Assert.notNull(transportProvider, "Transport provider must not be null"); + this.transportProvider = transportProvider; } /** @@ -599,7 +651,7 @@ private SyncSpec(ServerMcpTransport transport) { * @return This builder instance for method chaining * @throws IllegalArgumentException if serverInfo is null */ - public SyncSpec serverInfo(McpSchema.Implementation serverInfo) { + public SyncSpecification serverInfo(McpSchema.Implementation serverInfo) { Assert.notNull(serverInfo, "Server info must not be null"); this.serverInfo = serverInfo; return this; @@ -615,7 +667,7 @@ public SyncSpec serverInfo(McpSchema.Implementation serverInfo) { * @throws IllegalArgumentException if name or version is null or empty * @see #serverInfo(McpSchema.Implementation) */ - public SyncSpec serverInfo(String name, String version) { + public SyncSpecification serverInfo(String name, String version) { Assert.hasText(name, "Name must not be null or empty"); Assert.hasText(version, "Version must not be null or empty"); this.serverInfo = new McpSchema.Implementation(name, version); @@ -630,15 +682,14 @@ public SyncSpec serverInfo(String name, String version) { *
  • Tool execution *
  • Resource access *
  • Prompt handling - *
  • Streaming responses - *
  • Batch operations * * @param serverCapabilities The server capabilities configuration. Must not be * null. * @return This builder instance for method chaining * @throws IllegalArgumentException if serverCapabilities is null */ - public SyncSpec capabilities(McpSchema.ServerCapabilities serverCapabilities) { + public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + Assert.notNull(serverCapabilities, "Server capabilities must not be null"); this.serverCapabilities = serverCapabilities; return this; } @@ -646,26 +697,30 @@ public SyncSpec capabilities(McpSchema.ServerCapabilities serverCapabilities) { /** * Adds a single tool with its implementation handler to the server. This is a * convenience method for registering individual tools without creating a - * {@link ToolRegistration} explicitly. + * {@link McpServerFeatures.SyncToolSpecification} explicitly. * *

    * Example usage:

    {@code
     		 * .tool(
     		 *     new Tool("calculator", "Performs calculations", schema),
    -		 *     args -> new CallToolResult("Result: " + calculate(args))
    +		 *     (exchange, args) -> new CallToolResult("Result: " + calculate(args))
     		 * )
     		 * }
    * @param tool The tool definition including name, description, and schema. Must * not be null. * @param handler The function that implements the tool's logic. Must not be null. + * The function's first argument is an {@link McpSyncServerExchange} upon which + * the server can interact with the connected client. The second argument is the + * list of arguments passed to the tool. * @return This builder instance for method chaining * @throws IllegalArgumentException if tool or handler is null */ - public SyncSpec tool(McpSchema.Tool tool, Function, McpSchema.CallToolResult> handler) { + public SyncSpecification tool(McpSchema.Tool tool, + BiFunction, McpSchema.CallToolResult> handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); - this.tools.add(new McpServerFeatures.SyncToolRegistration(tool, handler)); + this.tools.add(new McpServerFeatures.SyncToolSpecification(tool, handler)); return this; } @@ -674,15 +729,15 @@ public SyncSpec tool(McpSchema.Tool tool, Function, McpSchem * Adds multiple tools with their handlers to the server using a List. This method * is useful when tools are dynamically generated or loaded from a configuration * source. - * @param toolRegistrations The list of tool registrations to add. Must not be + * @param toolSpecifications The list of tool specifications to add. Must not be * null. * @return This builder instance for method chaining - * @throws IllegalArgumentException if toolRegistrations is null - * @see #tools(McpServerFeatures.SyncToolRegistration...) + * @throws IllegalArgumentException if toolSpecifications is null + * @see #tools(McpServerFeatures.SyncToolSpecification...) */ - public SyncSpec tools(List toolRegistrations) { - Assert.notNull(toolRegistrations, "Tool handlers list must not be null"); - this.tools.addAll(toolRegistrations); + public SyncSpecification tools(List toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + this.tools.addAll(toolSpecifications); return this; } @@ -693,18 +748,19 @@ public SyncSpec tools(List toolRegistrat *

    * Example usage:

    {@code
     		 * .tools(
    -		 *     new ToolRegistration(calculatorTool, calculatorHandler),
    -		 *     new ToolRegistration(weatherTool, weatherHandler),
    -		 *     new ToolRegistration(fileManagerTool, fileManagerHandler)
    +		 *     new ToolSpecification(calculatorTool, calculatorHandler),
    +		 *     new ToolSpecification(weatherTool, weatherHandler),
    +		 *     new ToolSpecification(fileManagerTool, fileManagerHandler)
     		 * )
     		 * }
    - * @param toolRegistrations The tool registrations to add. Must not be null. + * @param toolSpecifications The tool specifications to add. Must not be null. * @return This builder instance for method chaining - * @throws IllegalArgumentException if toolRegistrations is null + * @throws IllegalArgumentException if toolSpecifications is null * @see #tools(List) */ - public SyncSpec tools(McpServerFeatures.SyncToolRegistration... toolRegistrations) { - for (McpServerFeatures.SyncToolRegistration tool : toolRegistrations) { + public SyncSpecification tools(McpServerFeatures.SyncToolSpecification... toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + for (McpServerFeatures.SyncToolSpecification tool : toolSpecifications) { this.tools.add(tool); } return this; @@ -714,29 +770,31 @@ public SyncSpec tools(McpServerFeatures.SyncToolRegistration... toolRegistration * Registers multiple resources with their handlers using a Map. This method is * useful when resources are dynamically generated or loaded from a configuration * source. - * @param resourceRegsitrations Map of resource name to registration. Must not be - * null. + * @param resourceSpecifications Map of resource name to specification. Must not + * be null. * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceRegsitrations is null - * @see #resources(McpServerFeatures.SyncResourceRegistration...) + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpServerFeatures.SyncResourceSpecification...) */ - public SyncSpec resources(Map resourceRegsitrations) { - Assert.notNull(resourceRegsitrations, "Resource handlers map must not be null"); - this.resources.putAll(resourceRegsitrations); + public SyncSpecification resources( + Map resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); + this.resources.putAll(resourceSpecifications); return this; } /** * Registers multiple resources with their handlers using a List. This method is * useful when resources need to be added in bulk from a collection. - * @param resourceRegsitrations List of resource registrations. Must not be null. + * @param resourceSpecifications List of resource specifications. Must not be + * null. * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceRegsitrations is null - * @see #resources(McpServerFeatures.SyncResourceRegistration...) + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpServerFeatures.SyncResourceSpecification...) */ - public SyncSpec resources(List resourceRegsitrations) { - Assert.notNull(resourceRegsitrations, "Resource handlers list must not be null"); - for (McpServerFeatures.SyncResourceRegistration resource : resourceRegsitrations) { + public SyncSpecification resources(List resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (McpServerFeatures.SyncResourceSpecification resource : resourceSpecifications) { this.resources.put(resource.resource().uri(), resource); } return this; @@ -749,19 +807,19 @@ public SyncSpec resources(List resou *

    * Example usage:

    {@code
     		 * .resources(
    -		 *     new ResourceRegistration(fileResource, fileHandler),
    -		 *     new ResourceRegistration(dbResource, dbHandler),
    -		 *     new ResourceRegistration(apiResource, apiHandler)
    +		 *     new ResourceSpecification(fileResource, fileHandler),
    +		 *     new ResourceSpecification(dbResource, dbHandler),
    +		 *     new ResourceSpecification(apiResource, apiHandler)
     		 * )
     		 * }
    - * @param resourceRegistrations The resource registrations to add. Must not be + * @param resourceSpecifications The resource specifications to add. Must not be * null. * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceRegistrations is null + * @throws IllegalArgumentException if resourceSpecifications is null */ - public SyncSpec resources(McpServerFeatures.SyncResourceRegistration... resourceRegistrations) { - Assert.notNull(resourceRegistrations, "Resource handlers list must not be null"); - for (McpServerFeatures.SyncResourceRegistration resource : resourceRegistrations) { + public SyncSpecification resources(McpServerFeatures.SyncResourceSpecification... resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (McpServerFeatures.SyncResourceSpecification resource : resourceSpecifications) { this.resources.put(resource.resource().uri(), resource); } return this; @@ -781,9 +839,11 @@ public SyncSpec resources(McpServerFeatures.SyncResourceRegistration... resource * @param resourceTemplates List of resource templates. If null, clears existing * templates. * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. * @see #resourceTemplates(ResourceTemplate...) */ - public SyncSpec resourceTemplates(List resourceTemplates) { + public SyncSpecification resourceTemplates(List resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); this.resourceTemplates.addAll(resourceTemplates); return this; } @@ -793,9 +853,11 @@ public SyncSpec resourceTemplates(List resourceTemplates) { * alternative to {@link #resourceTemplates(List)}. * @param resourceTemplates The resource templates to set. * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null * @see #resourceTemplates(List) */ - public SyncSpec resourceTemplates(ResourceTemplate... resourceTemplates) { + public SyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); for (ResourceTemplate resourceTemplate : resourceTemplates) { this.resourceTemplates.add(resourceTemplate); } @@ -809,18 +871,19 @@ public SyncSpec resourceTemplates(ResourceTemplate... resourceTemplates) { * *

    * Example usage:

    {@code
    -		 * Map prompts = new HashMap<>();
    -		 * prompts.put("analysis", new PromptRegistration(
    +		 * Map prompts = new HashMap<>();
    +		 * prompts.put("analysis", new PromptSpecification(
     		 *     new Prompt("analysis", "Code analysis template"),
    -		 *     request -> new GetPromptResult(generateAnalysisPrompt(request))
    +		 *     (exchange, request) -> new GetPromptResult(generateAnalysisPrompt(request))
     		 * ));
     		 * .prompts(prompts)
     		 * }
    - * @param prompts Map of prompt name to registration. Must not be null. + * @param prompts Map of prompt name to specification. Must not be null. * @return This builder instance for method chaining * @throws IllegalArgumentException if prompts is null */ - public SyncSpec prompts(Map prompts) { + public SyncSpecification prompts(Map prompts) { + Assert.notNull(prompts, "Prompts map must not be null"); this.prompts.putAll(prompts); return this; } @@ -828,13 +891,14 @@ public SyncSpec prompts(Map pr /** * Registers multiple prompts with their handlers using a List. This method is * useful when prompts need to be added in bulk from a collection. - * @param prompts List of prompt registrations. Must not be null. + * @param prompts List of prompt specifications. Must not be null. * @return This builder instance for method chaining * @throws IllegalArgumentException if prompts is null - * @see #prompts(McpServerFeatures.SyncPromptRegistration...) + * @see #prompts(McpServerFeatures.SyncPromptSpecification...) */ - public SyncSpec prompts(List prompts) { - for (McpServerFeatures.SyncPromptRegistration prompt : prompts) { + public SyncSpecification prompts(List prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (McpServerFeatures.SyncPromptSpecification prompt : prompts) { this.prompts.put(prompt.prompt().name(), prompt); } return this; @@ -847,17 +911,18 @@ public SyncSpec prompts(List prompts) *

    * Example usage:

    {@code
     		 * .prompts(
    -		 *     new PromptRegistration(analysisPrompt, analysisHandler),
    -		 *     new PromptRegistration(summaryPrompt, summaryHandler),
    -		 *     new PromptRegistration(reviewPrompt, reviewHandler)
    +		 *     new PromptSpecification(analysisPrompt, analysisHandler),
    +		 *     new PromptSpecification(summaryPrompt, summaryHandler),
    +		 *     new PromptSpecification(reviewPrompt, reviewHandler)
     		 * )
     		 * }
    - * @param prompts The prompt registrations to add. Must not be null. + * @param prompts The prompt specifications to add. Must not be null. * @return This builder instance for method chaining * @throws IllegalArgumentException if prompts is null */ - public SyncSpec prompts(McpServerFeatures.SyncPromptRegistration... prompts) { - for (McpServerFeatures.SyncPromptRegistration prompt : prompts) { + public SyncSpecification prompts(McpServerFeatures.SyncPromptSpecification... prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (McpServerFeatures.SyncPromptSpecification prompt : prompts) { this.prompts.put(prompt.prompt().name(), prompt); } return this; @@ -867,13 +932,15 @@ public SyncSpec prompts(McpServerFeatures.SyncPromptRegistration... prompts) { * Registers a consumer that will be notified when the list of roots changes. This * is useful for updating resource availability dynamically, such as when new * files are added or removed. - * @param consumer The consumer to register. Must not be null. + * @param handler The handler to register. Must not be null. The function's first + * argument is an {@link McpSyncServerExchange} upon which the server can interact + * with the connected client. The second argument is the list of roots. * @return This builder instance for method chaining * @throws IllegalArgumentException if consumer is null */ - public SyncSpec rootsChangeConsumer(Consumer> consumer) { - Assert.notNull(consumer, "Consumer must not be null"); - this.rootsChangeConsumers.add(consumer); + public SyncSpecification rootsChangeHandler(BiConsumer> handler) { + Assert.notNull(handler, "Consumer must not be null"); + this.rootsChangeHandlers.add(handler); return this; } @@ -881,13 +948,15 @@ public SyncSpec rootsChangeConsumer(Consumer> consumer) { * Registers multiple consumers that will be notified when the list of roots * changes. This method is useful when multiple consumers need to be registered at * once. - * @param consumers The list of consumers to register. Must not be null. + * @param handlers The list of handlers to register. Must not be null. * @return This builder instance for method chaining * @throws IllegalArgumentException if consumers is null + * @see #rootsChangeHandler(BiConsumer) */ - public SyncSpec rootsChangeConsumers(List>> consumers) { - Assert.notNull(consumers, "Consumers list must not be null"); - this.rootsChangeConsumers.addAll(consumers); + public SyncSpecification rootsChangeHandlers( + List>> handlers) { + Assert.notNull(handlers, "Handlers list must not be null"); + this.rootsChangeHandlers.addAll(handlers); return this; } @@ -895,45 +964,61 @@ public SyncSpec rootsChangeConsumers(List>> consum * Registers multiple consumers that will be notified when the list of roots * changes using varargs. This method provides a convenient way to register * multiple consumers inline. - * @param consumers The consumers to register. Must not be null. + * @param handlers The handlers to register. Must not be null. * @return This builder instance for method chaining * @throws IllegalArgumentException if consumers is null + * @see #rootsChangeHandlers(List) */ - public SyncSpec rootsChangeConsumers(Consumer>... consumers) { - for (Consumer> consumer : consumers) { - this.rootsChangeConsumers.add(consumer); - } + public SyncSpecification rootsChangeHandlers( + BiConsumer>... handlers) { + Assert.notNull(handlers, "Handlers list must not be null"); + return this.rootsChangeHandlers(List.of(handlers)); + } + + /** + * Sets the object mapper to use for serializing and deserializing JSON messages. + * @param objectMapper the instance to use. Must not be null. + * @return This builder instance for method chaining. + * @throws IllegalArgumentException if objectMapper is null + */ + public SyncSpecification objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; return this; } /** * Builds a synchronous MCP server that provides blocking operations. * @return A new instance of {@link McpSyncServer} configured with this builder's - * settings + * settings. */ public McpSyncServer build() { McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, - this.tools, this.resources, this.resourceTemplates, this.prompts, this.rootsChangeConsumers); - return new McpSyncServer( - new McpAsyncServer(this.transport, McpServerFeatures.Async.fromSync(syncFeatures))); + this.tools, this.resources, this.resourceTemplates, this.prompts, this.rootsChangeHandlers); + McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures); + var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures); + + return new McpSyncServer(asyncServer); } } /** - * Builder class for creating MCP servers with custom configuration. + * Asynchronous server specification. * - * @deprecated Use {@link #sync(ServerMcpTransport)} or - * {@link #async(ServerMcpTransport)} to create a server. + * @deprecated */ @Deprecated - public static class Builder { + class AsyncSpec { private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", "1.0.0"); private final ServerMcpTransport transport; + private ObjectMapper objectMapper; + private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; private McpSchema.ServerCapabilities serverCapabilities; @@ -945,7 +1030,7 @@ public static class Builder { * Each tool is uniquely identified by a name and includes metadata describing its * schema. */ - private final List tools = new ArrayList<>(); + private final List tools = new ArrayList<>(); /** * The Model Context Protocol (MCP) provides a standardized way for servers to @@ -954,9 +1039,9 @@ public static class Builder { * application-specific information. Each resource is uniquely identified by a * URI. */ - private Map resources = new HashMap<>(); + private final Map resources = new HashMap<>(); - private List resourceTemplates = new ArrayList<>(); + private final List resourceTemplates = new ArrayList<>(); /** * The Model Context Protocol (MCP) provides a standardized way for servers to @@ -965,11 +1050,11 @@ public static class Builder { * discover available prompts, retrieve their contents, and provide arguments to * customize them. */ - private Map prompts = new HashMap<>(); + private final Map prompts = new HashMap<>(); - private List>> rootsChangeConsumers = new ArrayList<>(); + private final List, Mono>> rootsChangeConsumers = new ArrayList<>(); - private Builder(ServerMcpTransport transport) { + private AsyncSpec(ServerMcpTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; } @@ -983,7 +1068,7 @@ private Builder(ServerMcpTransport transport) { * @return This builder instance for method chaining * @throws IllegalArgumentException if serverInfo is null */ - public Builder serverInfo(McpSchema.Implementation serverInfo) { + public AsyncSpec serverInfo(McpSchema.Implementation serverInfo) { Assert.notNull(serverInfo, "Server info must not be null"); this.serverInfo = serverInfo; return this; @@ -999,7 +1084,7 @@ public Builder serverInfo(McpSchema.Implementation serverInfo) { * @throws IllegalArgumentException if name or version is null or empty * @see #serverInfo(McpSchema.Implementation) */ - public Builder serverInfo(String name, String version) { + public AsyncSpec serverInfo(String name, String version) { Assert.hasText(name, "Name must not be null or empty"); Assert.hasText(version, "Version must not be null or empty"); this.serverInfo = new McpSchema.Implementation(name, version); @@ -1022,7 +1107,7 @@ public Builder serverInfo(String name, String version) { * @return This builder instance for method chaining * @throws IllegalArgumentException if serverCapabilities is null */ - public Builder capabilities(McpSchema.ServerCapabilities serverCapabilities) { + public AsyncSpec capabilities(McpSchema.ServerCapabilities serverCapabilities) { this.serverCapabilities = serverCapabilities; return this; } @@ -1030,13 +1115,13 @@ public Builder capabilities(McpSchema.ServerCapabilities serverCapabilities) { /** * Adds a single tool with its implementation handler to the server. This is a * convenience method for registering individual tools without creating a - * {@link ToolRegistration} explicitly. + * {@link McpServerFeatures.AsyncToolRegistration} explicitly. * *

    * Example usage:

    {@code
     		 * .tool(
     		 *     new Tool("calculator", "Performs calculations", schema),
    -		 *     args -> new CallToolResult("Result: " + calculate(args))
    +		 *     args -> Mono.just(new CallToolResult("Result: " + calculate(args)))
     		 * )
     		 * }
    * @param tool The tool definition including name, description, and schema. Must @@ -1045,11 +1130,11 @@ public Builder capabilities(McpSchema.ServerCapabilities serverCapabilities) { * @return This builder instance for method chaining * @throws IllegalArgumentException if tool or handler is null */ - public Builder tool(McpSchema.Tool tool, Function, McpSchema.CallToolResult> handler) { + public AsyncSpec tool(McpSchema.Tool tool, Function, Mono> handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); - this.tools.add(new ToolRegistration(tool, handler)); + this.tools.add(new McpServerFeatures.AsyncToolRegistration(tool, handler)); return this; } @@ -1062,9 +1147,9 @@ public Builder tool(McpSchema.Tool tool, Function, McpSchema * null. * @return This builder instance for method chaining * @throws IllegalArgumentException if toolRegistrations is null - * @see #tools(ToolRegistration...) + * @see #tools(McpServerFeatures.AsyncToolRegistration...) */ - public Builder tools(List toolRegistrations) { + public AsyncSpec tools(List toolRegistrations) { Assert.notNull(toolRegistrations, "Tool handlers list must not be null"); this.tools.addAll(toolRegistrations); return this; @@ -1077,9 +1162,9 @@ public Builder tools(List toolRegistrations) { *

    * Example usage:

    {@code
     		 * .tools(
    -		 *     new ToolRegistration(calculatorTool, calculatorHandler),
    -		 *     new ToolRegistration(weatherTool, weatherHandler),
    -		 *     new ToolRegistration(fileManagerTool, fileManagerHandler)
    +		 *     new McpServerFeatures.AsyncToolRegistration(calculatorTool, calculatorHandler),
    +		 *     new McpServerFeatures.AsyncToolRegistration(weatherTool, weatherHandler),
    +		 *     new McpServerFeatures.AsyncToolRegistration(fileManagerTool, fileManagerHandler)
     		 * )
     		 * }
    * @param toolRegistrations The tool registrations to add. Must not be null. @@ -1087,8 +1172,8 @@ public Builder tools(List toolRegistrations) { * @throws IllegalArgumentException if toolRegistrations is null * @see #tools(List) */ - public Builder tools(ToolRegistration... toolRegistrations) { - for (ToolRegistration tool : toolRegistrations) { + public AsyncSpec tools(McpServerFeatures.AsyncToolRegistration... toolRegistrations) { + for (McpServerFeatures.AsyncToolRegistration tool : toolRegistrations) { this.tools.add(tool); } return this; @@ -1102,9 +1187,9 @@ public Builder tools(ToolRegistration... toolRegistrations) { * null. * @return This builder instance for method chaining * @throws IllegalArgumentException if resourceRegsitrations is null - * @see #resources(ResourceRegistration...) + * @see #resources(McpServerFeatures.AsyncResourceRegistration...) */ - public Builder resources(Map resourceRegsitrations) { + public AsyncSpec resources(Map resourceRegsitrations) { Assert.notNull(resourceRegsitrations, "Resource handlers map must not be null"); this.resources.putAll(resourceRegsitrations); return this; @@ -1116,11 +1201,11 @@ public Builder resources(Map resourceRegsitrations * @param resourceRegsitrations List of resource registrations. Must not be null. * @return This builder instance for method chaining * @throws IllegalArgumentException if resourceRegsitrations is null - * @see #resources(ResourceRegistration...) + * @see #resources(McpServerFeatures.AsyncResourceRegistration...) */ - public Builder resources(List resourceRegsitrations) { + public AsyncSpec resources(List resourceRegsitrations) { Assert.notNull(resourceRegsitrations, "Resource handlers list must not be null"); - for (ResourceRegistration resource : resourceRegsitrations) { + for (McpServerFeatures.AsyncResourceRegistration resource : resourceRegsitrations) { this.resources.put(resource.resource().uri(), resource); } return this; @@ -1133,9 +1218,9 @@ public Builder resources(List resourceRegsitrations) { *

    * Example usage:

    {@code
     		 * .resources(
    -		 *     new ResourceRegistration(fileResource, fileHandler),
    -		 *     new ResourceRegistration(dbResource, dbHandler),
    -		 *     new ResourceRegistration(apiResource, apiHandler)
    +		 *     new McpServerFeatures.AsyncResourceRegistration(fileResource, fileHandler),
    +		 *     new McpServerFeatures.AsyncResourceRegistration(dbResource, dbHandler),
    +		 *     new McpServerFeatures.AsyncResourceRegistration(apiResource, apiHandler)
     		 * )
     		 * }
    * @param resourceRegistrations The resource registrations to add. Must not be @@ -1143,9 +1228,9 @@ public Builder resources(List resourceRegsitrations) { * @return This builder instance for method chaining * @throws IllegalArgumentException if resourceRegistrations is null */ - public Builder resources(ResourceRegistration... resourceRegistrations) { + public AsyncSpec resources(McpServerFeatures.AsyncResourceRegistration... resourceRegistrations) { Assert.notNull(resourceRegistrations, "Resource handlers list must not be null"); - for (ResourceRegistration resource : resourceRegistrations) { + for (McpServerFeatures.AsyncResourceRegistration resource : resourceRegistrations) { this.resources.put(resource.resource().uri(), resource); } return this; @@ -1167,8 +1252,8 @@ public Builder resources(ResourceRegistration... resourceRegistrations) { * @return This builder instance for method chaining * @see #resourceTemplates(ResourceTemplate...) */ - public Builder resourceTemplates(List resourceTemplates) { - this.resourceTemplates = resourceTemplates; + public AsyncSpec resourceTemplates(List resourceTemplates) { + this.resourceTemplates.addAll(resourceTemplates); return this; } @@ -1179,7 +1264,7 @@ public Builder resourceTemplates(List resourceTemplates) { * @return This builder instance for method chaining * @see #resourceTemplates(List) */ - public Builder resourceTemplates(ResourceTemplate... resourceTemplates) { + public AsyncSpec resourceTemplates(ResourceTemplate... resourceTemplates) { for (ResourceTemplate resourceTemplate : resourceTemplates) { this.resourceTemplates.add(resourceTemplate); } @@ -1193,18 +1278,16 @@ public Builder resourceTemplates(ResourceTemplate... resourceTemplates) { * *

    * Example usage:

    {@code
    -		 * Map prompts = new HashMap<>();
    -		 * prompts.put("analysis", new PromptRegistration(
    +		 * .prompts(Map.of("analysis", new McpServerFeatures.AsyncPromptRegistration(
     		 *     new Prompt("analysis", "Code analysis template"),
    -		 *     request -> new GetPromptResult(generateAnalysisPrompt(request))
    -		 * ));
    -		 * .prompts(prompts)
    +		 *     request -> Mono.just(new GetPromptResult(generateAnalysisPrompt(request)))
    +		 * )));
     		 * }
    * @param prompts Map of prompt name to registration. Must not be null. * @return This builder instance for method chaining * @throws IllegalArgumentException if prompts is null */ - public Builder prompts(Map prompts) { + public AsyncSpec prompts(Map prompts) { this.prompts.putAll(prompts); return this; } @@ -1215,10 +1298,10 @@ public Builder prompts(Map prompts) { * @param prompts List of prompt registrations. Must not be null. * @return This builder instance for method chaining * @throws IllegalArgumentException if prompts is null - * @see #prompts(PromptRegistration...) + * @see #prompts(McpServerFeatures.AsyncPromptRegistration...) */ - public Builder prompts(List prompts) { - for (PromptRegistration prompt : prompts) { + public AsyncSpec prompts(List prompts) { + for (McpServerFeatures.AsyncPromptRegistration prompt : prompts) { this.prompts.put(prompt.prompt().name(), prompt); } return this; @@ -1231,17 +1314,17 @@ public Builder prompts(List prompts) { *

    * Example usage:

    {@code
     		 * .prompts(
    -		 *     new PromptRegistration(analysisPrompt, analysisHandler),
    -		 *     new PromptRegistration(summaryPrompt, summaryHandler),
    -		 *     new PromptRegistration(reviewPrompt, reviewHandler)
    +		 *     new McpServerFeatures.AsyncPromptRegistration(analysisPrompt, analysisHandler),
    +		 *     new McpServerFeatures.AsyncPromptRegistration(summaryPrompt, summaryHandler),
    +		 *     new McpServerFeatures.AsyncPromptRegistration(reviewPrompt, reviewHandler)
     		 * )
     		 * }
    * @param prompts The prompt registrations to add. Must not be null. * @return This builder instance for method chaining * @throws IllegalArgumentException if prompts is null */ - public Builder prompts(PromptRegistration... prompts) { - for (PromptRegistration prompt : prompts) { + public AsyncSpec prompts(McpServerFeatures.AsyncPromptRegistration... prompts) { + for (McpServerFeatures.AsyncPromptRegistration prompt : prompts) { this.prompts.put(prompt.prompt().name(), prompt); } return this; @@ -1255,7 +1338,7 @@ public Builder prompts(PromptRegistration... prompts) { * @return This builder instance for method chaining * @throws IllegalArgumentException if consumer is null */ - public Builder rootsChangeConsumer(Consumer> consumer) { + public AsyncSpec rootsChangeConsumer(Function, Mono> consumer) { Assert.notNull(consumer, "Consumer must not be null"); this.rootsChangeConsumers.add(consumer); return this; @@ -1269,7 +1352,7 @@ public Builder rootsChangeConsumer(Consumer> consumer) { * @return This builder instance for method chaining * @throws IllegalArgumentException if consumers is null */ - public Builder rootsChangeConsumers(List>> consumers) { + public AsyncSpec rootsChangeConsumers(List, Mono>> consumers) { Assert.notNull(consumers, "Consumers list must not be null"); this.rootsChangeConsumers.addAll(consumers); return this; @@ -1283,184 +1366,457 @@ public Builder rootsChangeConsumers(List>> consume * @return This builder instance for method chaining * @throws IllegalArgumentException if consumers is null */ - public Builder rootsChangeConsumers(Consumer>... consumers) { - for (Consumer> consumer : consumers) { + public AsyncSpec rootsChangeConsumers( + @SuppressWarnings("unchecked") Function, Mono>... consumers) { + for (Function, Mono> consumer : consumers) { this.rootsChangeConsumers.add(consumer); } return this; } - /** - * Builds a synchronous MCP server that provides blocking operations. Synchronous - * servers process each request to completion before handling the next one, making - * them simpler to implement but potentially less performant for concurrent - * operations. - * @return A new instance of {@link McpSyncServer} configured with this builder's - * settings - * @deprecated Use {@link #sync(ServerMcpTransport)}. - */ - @Deprecated - public McpSyncServer sync() { - return new McpSyncServer(async()); - } - /** * Builds an asynchronous MCP server that provides non-blocking operations. - * Asynchronous servers can handle multiple requests concurrently using - * CompletableFuture, making them more efficient for high-concurrency scenarios - * but more complex to implement. * @return A new instance of {@link McpAsyncServer} configured with this builder's * settings - * @deprecated Use {@link #async(ServerMcpTransport)} */ - @Deprecated - public McpAsyncServer async() { - return new McpAsyncServer(transport, serverInfo, serverCapabilities, tools, resources, resourceTemplates, - prompts, rootsChangeConsumers); + public McpAsyncServer build() { + var tools = this.tools.stream().map(McpServerFeatures.AsyncToolRegistration::toSpecification).toList(); + + var resources = this.resources.entrySet() + .stream() + .map(entry -> Map.entry(entry.getKey(), entry.getValue().toSpecification())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + var prompts = this.prompts.entrySet() + .stream() + .map(entry -> Map.entry(entry.getKey(), entry.getValue().toSpecification())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + var rootsChangeHandlers = this.rootsChangeConsumers.stream() + .map(consumer -> (BiFunction, Mono>) (exchange, + roots) -> consumer.apply(roots)) + .toList(); + + var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, tools, resources, + this.resourceTemplates, prompts, rootsChangeHandlers); + + return new McpAsyncServer(this.transport, features); } } /** - * Registration of a tool with its handler function. Tools are the primary way for MCP - * servers to expose functionality to AI models. Each tool represents a specific - * capability, such as: - *
      - *
    • Performing calculations - *
    • Accessing external APIs - *
    • Querying databases - *
    • Manipulating files - *
    • Executing system commands - *
    - * - *

    - * Example tool registration:

    {@code
    -	 * new ToolRegistration(
    -	 *     new Tool(
    -	 *         "calculator",
    -	 *         "Performs mathematical calculations",
    -	 *         new JsonSchemaObject()
    -	 *             .required("expression")
    -	 *             .property("expression", JsonSchemaType.STRING)
    -	 *     ),
    -	 *     args -> {
    -	 *         String expr = (String) args.get("expression");
    -	 *         return new CallToolResult("Result: " + evaluate(expr));
    -	 *     }
    -	 * )
    -	 * }
    + * Synchronous server specification. * - * @param tool The tool definition including name, description, and parameter schema - * @param call The function that implements the tool's logic, receiving arguments and - * returning results - * @deprecated Use {@link McpServerFeatures.SyncToolRegistration} or - * {@link McpServerFeatures.AsyncToolRegistration}. + * @deprecated */ @Deprecated - public static record ToolRegistration(McpSchema.Tool tool, - Function, McpSchema.CallToolResult> call) { - } + class SyncSpec { - /** - * Registration of a resource with its handler function. Resources provide context to - * AI models by exposing data such as: - *
      - *
    • File contents - *
    • Database records - *
    • API responses - *
    • System information - *
    • Application state - *
    - * - *

    - * Example resource registration:

    {@code
    -	 * new ResourceRegistration(
    -	 *     new Resource("docs", "Documentation files", "text/markdown"),
    -	 *     request -> {
    -	 *         String content = readFile(request.getPath());
    -	 *         return new ReadResourceResult(content);
    -	 *     }
    -	 * )
    -	 * }
    - * - * @param resource The resource definition including name, description, and MIME type - * @param readHandler The function that handles resource read requests - * @deprecated Use {@link McpServerFeatures.SyncResourceRegistration} or - * {@link McpServerFeatures.AsyncResourceRegistration}. - */ - @Deprecated - public static record ResourceRegistration(McpSchema.Resource resource, - Function readHandler) { - } + private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", + "1.0.0"); - /** - * Registration of a prompt template with its handler function. Prompts provide - * structured templates for AI model interactions, supporting: - *
      - *
    • Consistent message formatting - *
    • Parameter substitution - *
    • Context injection - *
    • Response formatting - *
    • Instruction templating - *
    - * - *

    - * Example prompt registration:

    {@code
    -	 * new PromptRegistration(
    -	 *     new Prompt("analyze", "Code analysis template"),
    -	 *     request -> {
    -	 *         String code = request.getArguments().get("code");
    -	 *         return new GetPromptResult(
    -	 *             "Analyze this code:\n\n" + code + "\n\nProvide feedback on:"
    -	 *         );
    -	 *     }
    -	 * )
    -	 * }
    - * - * @param prompt The prompt definition including name and description - * @param promptHandler The function that processes prompt requests and returns - * formatted templates - * @deprecated Use {@link McpServerFeatures.SyncPromptRegistration} or - * {@link McpServerFeatures.AsyncPromptRegistration}. - */ - @Deprecated - public static record PromptRegistration(McpSchema.Prompt prompt, - Function promptHandler) { - } + private final ServerMcpTransport transport; - static McpServerFeatures.AsyncToolRegistration mapDeprecatedTool(ToolRegistration oldTool) { - return new McpServerFeatures.AsyncToolRegistration(oldTool.tool(), - map -> Mono.fromCallable(() -> oldTool.call().apply(map)).subscribeOn(Schedulers.boundedElastic())); - } + private final McpServerTransportProvider transportProvider; - static McpServerFeatures.AsyncResourceRegistration mapDeprecatedResource(ResourceRegistration oldResource) { - return new McpServerFeatures.AsyncResourceRegistration(oldResource.resource(), - req -> Mono.fromCallable(() -> oldResource.readHandler().apply(req)) - .subscribeOn(Schedulers.boundedElastic())); - } + private ObjectMapper objectMapper; - static McpServerFeatures.AsyncPromptRegistration mapDeprecatedPrompt(PromptRegistration oldPrompt) { - return new McpServerFeatures.AsyncPromptRegistration(oldPrompt.prompt(), - req -> Mono.fromCallable(() -> oldPrompt.promptHandler().apply(req)) - .subscribeOn(Schedulers.boundedElastic())); - } + private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; - static List mapDeprecatedTools(List oldTools) { - return oldTools.stream().map(McpServer::mapDeprecatedTool).toList(); - } + private McpSchema.ServerCapabilities serverCapabilities; - static Map mapDeprecatedResources( - Map oldResources) { - return oldResources.entrySet() - .stream() - .collect(Collectors.toMap(Map.Entry::getKey, e -> mapDeprecatedResource(e.getValue()))); - } + /** + * The Model Context Protocol (MCP) allows servers to expose tools that can be + * invoked by language models. Tools enable models to interact with external + * systems, such as querying databases, calling APIs, or performing computations. + * Each tool is uniquely identified by a name and includes metadata describing its + * schema. + */ + private final List tools = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resources to clients. Resources allow servers to share data that + * provides context to language models, such as files, database schemas, or + * application-specific information. Each resource is uniquely identified by a + * URI. + */ + private final Map resources = new HashMap<>(); + + private final List resourceTemplates = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose prompt templates to clients. Prompts allow servers to provide structured + * messages and instructions for interacting with language models. Clients can + * discover available prompts, retrieve their contents, and provide arguments to + * customize them. + */ + private final Map prompts = new HashMap<>(); + + private final List>> rootsChangeConsumers = new ArrayList<>(); + + private SyncSpec(McpServerTransportProvider transportProvider) { + Assert.notNull(transportProvider, "Transport provider must not be null"); + this.transportProvider = transportProvider; + this.transport = null; + } + + private SyncSpec(ServerMcpTransport transport) { + Assert.notNull(transport, "Transport must not be null"); + this.transport = transport; + this.transportProvider = null; + } + + /** + * Sets the server implementation information that will be shared with clients + * during connection initialization. This helps with version compatibility, + * debugging, and server identification. + * @param serverInfo The server implementation details including name and version. + * Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverInfo is null + */ + public SyncSpec serverInfo(McpSchema.Implementation serverInfo) { + Assert.notNull(serverInfo, "Server info must not be null"); + this.serverInfo = serverInfo; + return this; + } + + /** + * Sets the server implementation information using name and version strings. This + * is a convenience method alternative to + * {@link #serverInfo(McpSchema.Implementation)}. + * @param name The server name. Must not be null or empty. + * @param version The server version. Must not be null or empty. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if name or version is null or empty + * @see #serverInfo(McpSchema.Implementation) + */ + public SyncSpec serverInfo(String name, String version) { + Assert.hasText(name, "Name must not be null or empty"); + Assert.hasText(version, "Version must not be null or empty"); + this.serverInfo = new McpSchema.Implementation(name, version); + return this; + } + + /** + * Sets the server capabilities that will be advertised to clients during + * connection initialization. Capabilities define what features the server + * supports, such as: + *
      + *
    • Tool execution + *
    • Resource access + *
    • Prompt handling + *
    • Streaming responses + *
    • Batch operations + *
    + * @param serverCapabilities The server capabilities configuration. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverCapabilities is null + */ + public SyncSpec capabilities(McpSchema.ServerCapabilities serverCapabilities) { + this.serverCapabilities = serverCapabilities; + return this; + } + + /** + * Adds a single tool with its implementation handler to the server. This is a + * convenience method for registering individual tools without creating a + * {@link McpServerFeatures.SyncToolRegistration} explicitly. + * + *

    + * Example usage:

    {@code
    +		 * .tool(
    +		 *     new Tool("calculator", "Performs calculations", schema),
    +		 *     args -> new CallToolResult("Result: " + calculate(args))
    +		 * )
    +		 * }
    + * @param tool The tool definition including name, description, and schema. Must + * not be null. + * @param handler The function that implements the tool's logic. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if tool or handler is null + */ + public SyncSpec tool(McpSchema.Tool tool, Function, McpSchema.CallToolResult> handler) { + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(handler, "Handler must not be null"); + + this.tools.add(new McpServerFeatures.SyncToolRegistration(tool, handler)); + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using a List. This method + * is useful when tools are dynamically generated or loaded from a configuration + * source. + * @param toolRegistrations The list of tool registrations to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolRegistrations is null + * @see #tools(McpServerFeatures.SyncToolRegistration...) + */ + public SyncSpec tools(List toolRegistrations) { + Assert.notNull(toolRegistrations, "Tool handlers list must not be null"); + this.tools.addAll(toolRegistrations); + return this; + } + + /** + * Adds multiple tools with their handlers to the server using varargs. This + * method provides a convenient way to register multiple tools inline. + * + *

    + * Example usage:

    {@code
    +		 * .tools(
    +		 *     new ToolRegistration(calculatorTool, calculatorHandler),
    +		 *     new ToolRegistration(weatherTool, weatherHandler),
    +		 *     new ToolRegistration(fileManagerTool, fileManagerHandler)
    +		 * )
    +		 * }
    + * @param toolRegistrations The tool registrations to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolRegistrations is null + * @see #tools(List) + */ + public SyncSpec tools(McpServerFeatures.SyncToolRegistration... toolRegistrations) { + for (McpServerFeatures.SyncToolRegistration tool : toolRegistrations) { + this.tools.add(tool); + } + return this; + } + + /** + * Registers multiple resources with their handlers using a Map. This method is + * useful when resources are dynamically generated or loaded from a configuration + * source. + * @param resourceRegsitrations Map of resource name to registration. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceRegsitrations is null + * @see #resources(McpServerFeatures.SyncResourceRegistration...) + */ + public SyncSpec resources(Map resourceRegsitrations) { + Assert.notNull(resourceRegsitrations, "Resource handlers map must not be null"); + this.resources.putAll(resourceRegsitrations); + return this; + } + + /** + * Registers multiple resources with their handlers using a List. This method is + * useful when resources need to be added in bulk from a collection. + * @param resourceRegsitrations List of resource registrations. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceRegsitrations is null + * @see #resources(McpServerFeatures.SyncResourceRegistration...) + */ + public SyncSpec resources(List resourceRegsitrations) { + Assert.notNull(resourceRegsitrations, "Resource handlers list must not be null"); + for (McpServerFeatures.SyncResourceRegistration resource : resourceRegsitrations) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Registers multiple resources with their handlers using varargs. This method + * provides a convenient way to register multiple resources inline. + * + *

    + * Example usage:

    {@code
    +		 * .resources(
    +		 *     new ResourceRegistration(fileResource, fileHandler),
    +		 *     new ResourceRegistration(dbResource, dbHandler),
    +		 *     new ResourceRegistration(apiResource, apiHandler)
    +		 * )
    +		 * }
    + * @param resourceRegistrations The resource registrations to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceRegistrations is null + */ + public SyncSpec resources(McpServerFeatures.SyncResourceRegistration... resourceRegistrations) { + Assert.notNull(resourceRegistrations, "Resource handlers list must not be null"); + for (McpServerFeatures.SyncResourceRegistration resource : resourceRegistrations) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Sets the resource templates that define patterns for dynamic resource access. + * Templates use URI patterns with placeholders that can be filled at runtime. + * + *

    + * Example usage:

    {@code
    +		 * .resourceTemplates(
    +		 *     new ResourceTemplate("file://{path}", "Access files by path"),
    +		 *     new ResourceTemplate("db://{table}/{id}", "Access database records")
    +		 * )
    +		 * }
    + * @param resourceTemplates List of resource templates. If null, clears existing + * templates. + * @return This builder instance for method chaining + * @see #resourceTemplates(ResourceTemplate...) + */ + public SyncSpec resourceTemplates(List resourceTemplates) { + this.resourceTemplates.addAll(resourceTemplates); + return this; + } + + /** + * Sets the resource templates using varargs for convenience. This is an + * alternative to {@link #resourceTemplates(List)}. + * @param resourceTemplates The resource templates to set. + * @return This builder instance for method chaining + * @see #resourceTemplates(List) + */ + public SyncSpec resourceTemplates(ResourceTemplate... resourceTemplates) { + for (ResourceTemplate resourceTemplate : resourceTemplates) { + this.resourceTemplates.add(resourceTemplate); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using a Map. This method is + * useful when prompts are dynamically generated or loaded from a configuration + * source. + * + *

    + * Example usage:

    {@code
    +		 * Map prompts = new HashMap<>();
    +		 * prompts.put("analysis", new PromptRegistration(
    +		 *     new Prompt("analysis", "Code analysis template"),
    +		 *     request -> new GetPromptResult(generateAnalysisPrompt(request))
    +		 * ));
    +		 * .prompts(prompts)
    +		 * }
    + * @param prompts Map of prompt name to registration. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public SyncSpec prompts(Map prompts) { + this.prompts.putAll(prompts); + return this; + } + + /** + * Registers multiple prompts with their handlers using a List. This method is + * useful when prompts need to be added in bulk from a collection. + * @param prompts List of prompt registrations. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + * @see #prompts(McpServerFeatures.SyncPromptRegistration...) + */ + public SyncSpec prompts(List prompts) { + for (McpServerFeatures.SyncPromptRegistration prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using varargs. This method + * provides a convenient way to register multiple prompts inline. + * + *

    + * Example usage:

    {@code
    +		 * .prompts(
    +		 *     new PromptRegistration(analysisPrompt, analysisHandler),
    +		 *     new PromptRegistration(summaryPrompt, summaryHandler),
    +		 *     new PromptRegistration(reviewPrompt, reviewHandler)
    +		 * )
    +		 * }
    + * @param prompts The prompt registrations to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public SyncSpec prompts(McpServerFeatures.SyncPromptRegistration... prompts) { + for (McpServerFeatures.SyncPromptRegistration prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers a consumer that will be notified when the list of roots changes. This + * is useful for updating resource availability dynamically, such as when new + * files are added or removed. + * @param consumer The consumer to register. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumer is null + */ + public SyncSpec rootsChangeConsumer(Consumer> consumer) { + Assert.notNull(consumer, "Consumer must not be null"); + this.rootsChangeConsumers.add(consumer); + return this; + } + + /** + * Registers multiple consumers that will be notified when the list of roots + * changes. This method is useful when multiple consumers need to be registered at + * once. + * @param consumers The list of consumers to register. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumers is null + */ + public SyncSpec rootsChangeConsumers(List>> consumers) { + Assert.notNull(consumers, "Consumers list must not be null"); + this.rootsChangeConsumers.addAll(consumers); + return this; + } + + /** + * Registers multiple consumers that will be notified when the list of roots + * changes using varargs. This method provides a convenient way to register + * multiple consumers inline. + * @param consumers The consumers to register. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumers is null + */ + public SyncSpec rootsChangeConsumers(Consumer>... consumers) { + for (Consumer> consumer : consumers) { + this.rootsChangeConsumers.add(consumer); + } + return this; + } + + /** + * Builds a synchronous MCP server that provides blocking operations. + * @return A new instance of {@link McpSyncServer} configured with this builder's + * settings + */ + public McpSyncServer build() { + var tools = this.tools.stream().map(McpServerFeatures.SyncToolRegistration::toSpecification).toList(); + + var resources = this.resources.entrySet() + .stream() + .map(entry -> Map.entry(entry.getKey(), entry.getValue().toSpecification())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + var prompts = this.prompts.entrySet() + .stream() + .map(entry -> Map.entry(entry.getKey(), entry.getValue().toSpecification())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + var rootsChangeHandlers = this.rootsChangeConsumers.stream() + .map(consumer -> (BiConsumer>) (exchange, roots) -> consumer + .accept(roots)) + .toList(); + + McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, + tools, resources, this.resourceTemplates, prompts, rootsChangeHandlers); + + McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures); + var asyncServer = new McpAsyncServer(this.transport, asyncFeatures); + + return new McpSyncServer(asyncServer); + } - static Map mapDeprecatedPrompts( - Map oldPrompts) { - return oldPrompts.entrySet() - .stream() - .collect(Collectors.toMap(Map.Entry::getKey, e -> mapDeprecatedPrompt(e.getValue()))); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java new file mode 100644 index 000000000..5aeeadd77 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -0,0 +1,694 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Function; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + +/** + * MCP server features specification that a particular server can choose to support. + * + * @author Dariusz Jędrzejczyk + */ +public class McpServerFeatures { + + /** + * Asynchronous server features specification. + * + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The list of resource templates + * @param prompts The map of prompt specifications + * @param rootsChangeConsumers The list of consumers that will be notified when the + * roots list changes + */ + record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, Map resources, + List resourceTemplates, + Map prompts, + List, Mono>> rootsChangeConsumers) { + + /** + * Create an instance and validate the arguments. + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The list of resource templates + * @param prompts The map of prompt specifications + * @param rootsChangeConsumers The list of consumers that will be notified when + * the roots list changes + */ + Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, Map resources, + List resourceTemplates, + Map prompts, + List, Mono>> rootsChangeConsumers) { + + Assert.notNull(serverInfo, "Server info must not be null"); + + this.serverInfo = serverInfo; + this.serverCapabilities = (serverCapabilities != null) ? serverCapabilities + : new McpSchema.ServerCapabilities(null, // experimental + new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable + // logging + // by + // default + !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, + !Utils.isEmpty(resources) + ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, + !Utils.isEmpty(tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null); + + this.tools = (tools != null) ? tools : List.of(); + this.resources = (resources != null) ? resources : Map.of(); + this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : List.of(); + this.prompts = (prompts != null) ? prompts : Map.of(); + this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : List.of(); + } + + /** + * Convert a synchronous specification into an asynchronous one and provide + * blocking code offloading to prevent accidental blocking of the non-blocking + * transport. + * @param syncSpec a potentially blocking, synchronous specification. + * @return a specification which is protected from blocking calls specified by the + * user. + */ + static Async fromSync(Sync syncSpec) { + List tools = new ArrayList<>(); + for (var tool : syncSpec.tools()) { + tools.add(AsyncToolSpecification.fromSync(tool)); + } + + Map resources = new HashMap<>(); + syncSpec.resources().forEach((key, resource) -> { + resources.put(key, AsyncResourceSpecification.fromSync(resource)); + }); + + Map prompts = new HashMap<>(); + syncSpec.prompts().forEach((key, prompt) -> { + prompts.put(key, AsyncPromptSpecification.fromSync(prompt)); + }); + + List, Mono>> rootChangeConsumers = new ArrayList<>(); + + for (var rootChangeConsumer : syncSpec.rootsChangeConsumers()) { + rootChangeConsumers.add((exchange, list) -> Mono + .fromRunnable(() -> rootChangeConsumer.accept(new McpSyncServerExchange(exchange), list)) + .subscribeOn(Schedulers.boundedElastic())); + } + + return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, resources, + syncSpec.resourceTemplates(), prompts, rootChangeConsumers); + } + } + + /** + * Synchronous server features specification. + * + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The list of resource templates + * @param prompts The map of prompt specifications + * @param rootsChangeConsumers The list of consumers that will be notified when the + * roots list changes + */ + record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, + Map resources, + List resourceTemplates, + Map prompts, + List>> rootsChangeConsumers) { + + /** + * Create an instance and validate the arguments. + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The list of resource templates + * @param prompts The map of prompt specifications + * @param rootsChangeConsumers The list of consumers that will be notified when + * the roots list changes + */ + Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, + Map resources, + List resourceTemplates, + Map prompts, + List>> rootsChangeConsumers) { + + Assert.notNull(serverInfo, "Server info must not be null"); + + this.serverInfo = serverInfo; + this.serverCapabilities = (serverCapabilities != null) ? serverCapabilities + : new McpSchema.ServerCapabilities(null, // experimental + new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable + // logging + // by + // default + !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, + !Utils.isEmpty(resources) + ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, + !Utils.isEmpty(tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null); + + this.tools = (tools != null) ? tools : new ArrayList<>(); + this.resources = (resources != null) ? resources : new HashMap<>(); + this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : new ArrayList<>(); + this.prompts = (prompts != null) ? prompts : new HashMap<>(); + this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : new ArrayList<>(); + } + + } + + /** + * Specification of a tool with its asynchronous handler function. Tools are the + * primary way for MCP servers to expose functionality to AI models. Each tool + * represents a specific capability, such as: + *
      + *
    • Performing calculations + *
    • Accessing external APIs + *
    • Querying databases + *
    • Manipulating files + *
    • Executing system commands + *
    + * + *

    + * Example tool specification:

    {@code
    +	 * new McpServerFeatures.AsyncToolSpecification(
    +	 *     new Tool(
    +	 *         "calculator",
    +	 *         "Performs mathematical calculations",
    +	 *         new JsonSchemaObject()
    +	 *             .required("expression")
    +	 *             .property("expression", JsonSchemaType.STRING)
    +	 *     ),
    +	 *     (exchange, args) -> {
    +	 *         String expr = (String) args.get("expression");
    +	 *         return Mono.fromSupplier(() -> evaluate(expr))
    +	 *             .map(result -> new CallToolResult("Result: " + result));
    +	 *     }
    +	 * )
    +	 * }
    + * + * @param tool The tool definition including name, description, and parameter schema + * @param call The function that implements the tool's logic, receiving arguments and + * returning results. The function's first argument is an + * {@link McpAsyncServerExchange} upon which the server can interact with the + * connected client. The second arguments is a map of tool arguments. + */ + public record AsyncToolSpecification(McpSchema.Tool tool, + BiFunction, Mono> call) { + + static AsyncToolSpecification fromSync(SyncToolSpecification tool) { + // FIXME: This is temporary, proper validation should be implemented + if (tool == null) { + return null; + } + return new AsyncToolSpecification(tool.tool(), + (exchange, map) -> Mono + .fromCallable(() -> tool.call().apply(new McpSyncServerExchange(exchange), map)) + .subscribeOn(Schedulers.boundedElastic())); + } + } + + /** + * Specification of a resource with its asynchronous handler function. Resources + * provide context to AI models by exposing data such as: + *
      + *
    • File contents + *
    • Database records + *
    • API responses + *
    • System information + *
    • Application state + *
    + * + *

    + * Example resource specification:

    {@code
    +	 * new McpServerFeatures.AsyncResourceSpecification(
    +	 *     new Resource("docs", "Documentation files", "text/markdown"),
    +	 *     (exchange, request) ->
    +	 *         Mono.fromSupplier(() -> readFile(request.getPath()))
    +	 *             .map(ReadResourceResult::new)
    +	 * )
    +	 * }
    + * + * @param resource The resource definition including name, description, and MIME type + * @param readHandler The function that handles resource read requests. The function's + * first argument is an {@link McpAsyncServerExchange} upon which the server can + * interact with the connected client. The second arguments is a + * {@link io.modelcontextprotocol.spec.McpSchema.ReadResourceRequest}. + */ + public record AsyncResourceSpecification(McpSchema.Resource resource, + BiFunction> readHandler) { + + static AsyncResourceSpecification fromSync(SyncResourceSpecification resource) { + // FIXME: This is temporary, proper validation should be implemented + if (resource == null) { + return null; + } + return new AsyncResourceSpecification(resource.resource(), + (exchange, req) -> Mono + .fromCallable(() -> resource.readHandler().apply(new McpSyncServerExchange(exchange), req)) + .subscribeOn(Schedulers.boundedElastic())); + } + } + + /** + * Specification of a prompt template with its asynchronous handler function. Prompts + * provide structured templates for AI model interactions, supporting: + *
      + *
    • Consistent message formatting + *
    • Parameter substitution + *
    • Context injection + *
    • Response formatting + *
    • Instruction templating + *
    + * + *

    + * Example prompt specification:

    {@code
    +	 * new McpServerFeatures.AsyncPromptSpecification(
    +	 *     new Prompt("analyze", "Code analysis template"),
    +	 *     (exchange, request) -> {
    +	 *         String code = request.getArguments().get("code");
    +	 *         return Mono.just(new GetPromptResult(
    +	 *             "Analyze this code:\n\n" + code + "\n\nProvide feedback on:"
    +	 *         ));
    +	 *     }
    +	 * )
    +	 * }
    + * + * @param prompt The prompt definition including name and description + * @param promptHandler The function that processes prompt requests and returns + * formatted templates. The function's first argument is an + * {@link McpAsyncServerExchange} upon which the server can interact with the + * connected client. The second arguments is a + * {@link io.modelcontextprotocol.spec.McpSchema.GetPromptRequest}. + */ + public record AsyncPromptSpecification(McpSchema.Prompt prompt, + BiFunction> promptHandler) { + + static AsyncPromptSpecification fromSync(SyncPromptSpecification prompt) { + // FIXME: This is temporary, proper validation should be implemented + if (prompt == null) { + return null; + } + return new AsyncPromptSpecification(prompt.prompt(), + (exchange, req) -> Mono + .fromCallable(() -> prompt.promptHandler().apply(new McpSyncServerExchange(exchange), req)) + .subscribeOn(Schedulers.boundedElastic())); + } + } + + /** + * Specification of a tool with its synchronous handler function. Tools are the + * primary way for MCP servers to expose functionality to AI models. Each tool + * represents a specific capability, such as: + *
      + *
    • Performing calculations + *
    • Accessing external APIs + *
    • Querying databases + *
    • Manipulating files + *
    • Executing system commands + *
    + * + *

    + * Example tool specification:

    {@code
    +	 * new McpServerFeatures.SyncToolSpecification(
    +	 *     new Tool(
    +	 *         "calculator",
    +	 *         "Performs mathematical calculations",
    +	 *         new JsonSchemaObject()
    +	 *             .required("expression")
    +	 *             .property("expression", JsonSchemaType.STRING)
    +	 *     ),
    +	 *     (exchange, args) -> {
    +	 *         String expr = (String) args.get("expression");
    +	 *         return new CallToolResult("Result: " + evaluate(expr));
    +	 *     }
    +	 * )
    +	 * }
    + * + * @param tool The tool definition including name, description, and parameter schema + * @param call The function that implements the tool's logic, receiving arguments and + * returning results. The function's first argument is an + * {@link McpSyncServerExchange} upon which the server can interact with the connected + * client. The second arguments is a map of arguments passed to the tool. + */ + public record SyncToolSpecification(McpSchema.Tool tool, + BiFunction, McpSchema.CallToolResult> call) { + } + + /** + * Specification of a resource with its synchronous handler function. Resources + * provide context to AI models by exposing data such as: + *
      + *
    • File contents + *
    • Database records + *
    • API responses + *
    • System information + *
    • Application state + *
    + * + *

    + * Example resource specification:

    {@code
    +	 * new McpServerFeatures.SyncResourceSpecification(
    +	 *     new Resource("docs", "Documentation files", "text/markdown"),
    +	 *     (exchange, request) -> {
    +	 *         String content = readFile(request.getPath());
    +	 *         return new ReadResourceResult(content);
    +	 *     }
    +	 * )
    +	 * }
    + * + * @param resource The resource definition including name, description, and MIME type + * @param readHandler The function that handles resource read requests. The function's + * first argument is an {@link McpSyncServerExchange} upon which the server can + * interact with the connected client. The second arguments is a + * {@link io.modelcontextprotocol.spec.McpSchema.ReadResourceRequest}. + */ + public record SyncResourceSpecification(McpSchema.Resource resource, + BiFunction readHandler) { + } + + /** + * Specification of a prompt template with its synchronous handler function. Prompts + * provide structured templates for AI model interactions, supporting: + *
      + *
    • Consistent message formatting + *
    • Parameter substitution + *
    • Context injection + *
    • Response formatting + *
    • Instruction templating + *
    + * + *

    + * Example prompt specification:

    {@code
    +	 * new McpServerFeatures.SyncPromptSpecification(
    +	 *     new Prompt("analyze", "Code analysis template"),
    +	 *     (exchange, request) -> {
    +	 *         String code = request.getArguments().get("code");
    +	 *         return new GetPromptResult(
    +	 *             "Analyze this code:\n\n" + code + "\n\nProvide feedback on:"
    +	 *         );
    +	 *     }
    +	 * )
    +	 * }
    + * + * @param prompt The prompt definition including name and description + * @param promptHandler The function that processes prompt requests and returns + * formatted templates. The function's first argument is an + * {@link McpSyncServerExchange} upon which the server can interact with the connected + * client. The second arguments is a + * {@link io.modelcontextprotocol.spec.McpSchema.GetPromptRequest}. + */ + public record SyncPromptSpecification(McpSchema.Prompt prompt, + BiFunction promptHandler) { + } + + // --------------------------------------- + // Deprecated registrations + // --------------------------------------- + + /** + * Registration of a tool with its asynchronous handler function. Tools are the + * primary way for MCP servers to expose functionality to AI models. Each tool + * represents a specific capability, such as: + *
      + *
    • Performing calculations + *
    • Accessing external APIs + *
    • Querying databases + *
    • Manipulating files + *
    • Executing system commands + *
    + * + *

    + * Example tool registration:

    {@code
    +	 * new McpServerFeatures.AsyncToolRegistration(
    +	 *     new Tool(
    +	 *         "calculator",
    +	 *         "Performs mathematical calculations",
    +	 *         new JsonSchemaObject()
    +	 *             .required("expression")
    +	 *             .property("expression", JsonSchemaType.STRING)
    +	 *     ),
    +	 *     args -> {
    +	 *         String expr = (String) args.get("expression");
    +	 *         return Mono.just(new CallToolResult("Result: " + evaluate(expr)));
    +	 *     }
    +	 * )
    +	 * }
    + * + * @param tool The tool definition including name, description, and parameter schema + * @param call The function that implements the tool's logic, receiving arguments and + * returning results + * @deprecated This class is deprecated and will be removed in 0.9.0. Use + * {@link AsyncToolSpecification}. + */ + @Deprecated + public record AsyncToolRegistration(McpSchema.Tool tool, + Function, Mono> call) { + + static AsyncToolRegistration fromSync(SyncToolRegistration tool) { + // FIXME: This is temporary, proper validation should be implemented + if (tool == null) { + return null; + } + return new AsyncToolRegistration(tool.tool(), + map -> Mono.fromCallable(() -> tool.call().apply(map)).subscribeOn(Schedulers.boundedElastic())); + } + + public AsyncToolSpecification toSpecification() { + return new AsyncToolSpecification(tool(), (exchange, map) -> call.apply(map)); + } + } + + /** + * Registration of a resource with its asynchronous handler function. Resources + * provide context to AI models by exposing data such as: + *
      + *
    • File contents + *
    • Database records + *
    • API responses + *
    • System information + *
    • Application state + *
    + * + *

    + * Example resource registration:

    {@code
    +	 * new McpServerFeatures.AsyncResourceRegistration(
    +	 *     new Resource("docs", "Documentation files", "text/markdown"),
    +	 *     request -> {
    +	 *         String content = readFile(request.getPath());
    +	 *         return Mono.just(new ReadResourceResult(content));
    +	 *     }
    +	 * )
    +	 * }
    + * + * @param resource The resource definition including name, description, and MIME type + * @param readHandler The function that handles resource read requests + * @deprecated This class is deprecated and will be removed in 0.9.0. Use + * {@link AsyncResourceSpecification}. + */ + @Deprecated + public record AsyncResourceRegistration(McpSchema.Resource resource, + Function> readHandler) { + + static AsyncResourceRegistration fromSync(SyncResourceRegistration resource) { + // FIXME: This is temporary, proper validation should be implemented + if (resource == null) { + return null; + } + return new AsyncResourceRegistration(resource.resource(), + req -> Mono.fromCallable(() -> resource.readHandler().apply(req)) + .subscribeOn(Schedulers.boundedElastic())); + } + + public AsyncResourceSpecification toSpecification() { + return new AsyncResourceSpecification(resource(), (exchange, request) -> readHandler.apply(request)); + } + } + + /** + * Registration of a prompt template with its asynchronous handler function. Prompts + * provide structured templates for AI model interactions, supporting: + *
      + *
    • Consistent message formatting + *
    • Parameter substitution + *
    • Context injection + *
    • Response formatting + *
    • Instruction templating + *
    + * + *

    + * Example prompt registration:

    {@code
    +	 * new McpServerFeatures.AsyncPromptRegistration(
    +	 *     new Prompt("analyze", "Code analysis template"),
    +	 *     request -> {
    +	 *         String code = request.getArguments().get("code");
    +	 *         return Mono.just(new GetPromptResult(
    +	 *             "Analyze this code:\n\n" + code + "\n\nProvide feedback on:"
    +	 *         ));
    +	 *     }
    +	 * )
    +	 * }
    + * + * @param prompt The prompt definition including name and description + * @param promptHandler The function that processes prompt requests and returns + * formatted templates + * @deprecated This class is deprecated and will be removed in 0.9.0. Use + * {@link AsyncPromptSpecification}. + */ + @Deprecated + public record AsyncPromptRegistration(McpSchema.Prompt prompt, + Function> promptHandler) { + + static AsyncPromptRegistration fromSync(SyncPromptRegistration prompt) { + // FIXME: This is temporary, proper validation should be implemented + if (prompt == null) { + return null; + } + return new AsyncPromptRegistration(prompt.prompt(), + req -> Mono.fromCallable(() -> prompt.promptHandler().apply(req)) + .subscribeOn(Schedulers.boundedElastic())); + } + + public AsyncPromptSpecification toSpecification() { + return new AsyncPromptSpecification(prompt(), (exchange, request) -> promptHandler.apply(request)); + } + } + + /** + * Registration of a tool with its synchronous handler function. Tools are the primary + * way for MCP servers to expose functionality to AI models. Each tool represents a + * specific capability, such as: + *
      + *
    • Performing calculations + *
    • Accessing external APIs + *
    • Querying databases + *
    • Manipulating files + *
    • Executing system commands + *
    + * + *

    + * Example tool registration:

    {@code
    +	 * new McpServerFeatures.SyncToolRegistration(
    +	 *     new Tool(
    +	 *         "calculator",
    +	 *         "Performs mathematical calculations",
    +	 *         new JsonSchemaObject()
    +	 *             .required("expression")
    +	 *             .property("expression", JsonSchemaType.STRING)
    +	 *     ),
    +	 *     args -> {
    +	 *         String expr = (String) args.get("expression");
    +	 *         return new CallToolResult("Result: " + evaluate(expr));
    +	 *     }
    +	 * )
    +	 * }
    + * + * @param tool The tool definition including name, description, and parameter schema + * @param call The function that implements the tool's logic, receiving arguments and + * returning results + * @deprecated This class is deprecated and will be removed in 0.9.0. Use + * {@link SyncToolSpecification}. + */ + @Deprecated + public record SyncToolRegistration(McpSchema.Tool tool, + Function, McpSchema.CallToolResult> call) { + public SyncToolSpecification toSpecification() { + return new SyncToolSpecification(tool, (exchange, map) -> call.apply(map)); + } + } + + /** + * Registration of a resource with its synchronous handler function. Resources provide + * context to AI models by exposing data such as: + *
      + *
    • File contents + *
    • Database records + *
    • API responses + *
    • System information + *
    • Application state + *
    + * + *

    + * Example resource registration:

    {@code
    +	 * new McpServerFeatures.SyncResourceRegistration(
    +	 *     new Resource("docs", "Documentation files", "text/markdown"),
    +	 *     request -> {
    +	 *         String content = readFile(request.getPath());
    +	 *         return new ReadResourceResult(content);
    +	 *     }
    +	 * )
    +	 * }
    + * + * @param resource The resource definition including name, description, and MIME type + * @param readHandler The function that handles resource read requests + * @deprecated This class is deprecated and will be removed in 0.9.0. Use + * {@link SyncResourceSpecification}. + */ + @Deprecated + public record SyncResourceRegistration(McpSchema.Resource resource, + Function readHandler) { + public SyncResourceSpecification toSpecification() { + return new SyncResourceSpecification(resource, (exchange, request) -> readHandler.apply(request)); + } + } + + /** + * Registration of a prompt template with its synchronous handler function. Prompts + * provide structured templates for AI model interactions, supporting: + *
      + *
    • Consistent message formatting + *
    • Parameter substitution + *
    • Context injection + *
    • Response formatting + *
    • Instruction templating + *
    + * + *

    + * Example prompt registration:

    {@code
    +	 * new McpServerFeatures.SyncPromptRegistration(
    +	 *     new Prompt("analyze", "Code analysis template"),
    +	 *     request -> {
    +	 *         String code = request.getArguments().get("code");
    +	 *         return new GetPromptResult(
    +	 *             "Analyze this code:\n\n" + code + "\n\nProvide feedback on:"
    +	 *         );
    +	 *     }
    +	 * )
    +	 * }
    + * + * @param prompt The prompt definition including name and description + * @param promptHandler The function that processes prompt requests and returns + * formatted templates + * @deprecated This class is deprecated and will be removed in 0.9.0. Use + * {@link SyncPromptSpecification}. + */ + @Deprecated + public record SyncPromptRegistration(McpSchema.Prompt prompt, + Function promptHandler) { + public SyncPromptSpecification toSpecification() { + return new SyncPromptSpecification(prompt, (exchange, request) -> promptHandler.apply(request)); + } + } + +} diff --git a/mcp/src/main/java/org/springframework/ai/mcp/server/McpSyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java similarity index 79% rename from mcp/src/main/java/org/springframework/ai/mcp/server/McpSyncServer.java rename to mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java index 004f7a3d0..60662d98d 100644 --- a/mcp/src/main/java/org/springframework/ai/mcp/server/McpSyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java @@ -1,29 +1,14 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.server; +package io.modelcontextprotocol.server; -import org.springframework.ai.mcp.server.McpServer.PromptRegistration; -import org.springframework.ai.mcp.server.McpServer.ResourceRegistration; -import org.springframework.ai.mcp.server.McpServer.ToolRegistration; -import org.springframework.ai.mcp.spec.McpError; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.ClientCapabilities; -import org.springframework.ai.mcp.spec.McpSchema.LoggingMessageNotification; -import org.springframework.ai.mcp.util.Assert; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.util.Assert; /** * A synchronous implementation of the Model Context Protocol (MCP) server that wraps @@ -83,7 +68,10 @@ public McpSyncServer(McpAsyncServer asyncServer) { /** * Retrieves the list of all roots provided by the client. * @return The list of roots + * @deprecated This method will be removed in 0.9.0. Use + * {@link McpSyncServerExchange#listRoots()}. */ + @Deprecated public McpSchema.ListRootsResult listRoots() { return this.listRoots(null); } @@ -92,7 +80,10 @@ public McpSchema.ListRootsResult listRoots() { * Retrieves a paginated list of roots provided by the server. * @param cursor Optional pagination cursor from a previous list request * @return The list of roots + * @deprecated This method will be removed in 0.9.0. Use + * {@link McpSyncServerExchange#listRoots(String)}. */ + @Deprecated public McpSchema.ListRootsResult listRoots(String cursor) { return this.asyncServer.listRoots(cursor).block(); } @@ -100,7 +91,10 @@ public McpSchema.ListRootsResult listRoots(String cursor) { /** * Add a new tool handler. * @param toolHandler The tool handler to add + * @deprecated This method will be removed in 0.9.0. Use + * {@link #addTool(McpServerFeatures.SyncToolSpecification)}. */ + @Deprecated public void addTool(McpServerFeatures.SyncToolRegistration toolHandler) { this.asyncServer.addTool(McpServerFeatures.AsyncToolRegistration.fromSync(toolHandler)).block(); } @@ -108,11 +102,9 @@ public void addTool(McpServerFeatures.SyncToolRegistration toolHandler) { /** * Add a new tool handler. * @param toolHandler The tool handler to add - * @deprecated Use {@link #addTool(McpServerFeatures.SyncToolRegistration)}. */ - @Deprecated - public void addTool(ToolRegistration toolHandler) { - this.asyncServer.addTool(toolHandler).block(); + public void addTool(McpServerFeatures.SyncToolSpecification toolHandler) { + this.asyncServer.addTool(McpServerFeatures.AsyncToolSpecification.fromSync(toolHandler)).block(); } /** @@ -126,7 +118,10 @@ public void removeTool(String toolName) { /** * Add a new resource handler. * @param resourceHandler The resource handler to add + * @deprecated This method will be removed in 0.9.0. Use + * {@link #addResource(McpServerFeatures.SyncResourceSpecification)}. */ + @Deprecated public void addResource(McpServerFeatures.SyncResourceRegistration resourceHandler) { this.asyncServer.addResource(McpServerFeatures.AsyncResourceRegistration.fromSync(resourceHandler)).block(); } @@ -134,11 +129,9 @@ public void addResource(McpServerFeatures.SyncResourceRegistration resourceHandl /** * Add a new resource handler. * @param resourceHandler The resource handler to add - * @deprecated Use {@link #addResource(McpServerFeatures.SyncResourceRegistration)}. */ - @Deprecated - public void addResource(ResourceRegistration resourceHandler) { - this.asyncServer.addResource(resourceHandler).block(); + public void addResource(McpServerFeatures.SyncResourceSpecification resourceHandler) { + this.asyncServer.addResource(McpServerFeatures.AsyncResourceSpecification.fromSync(resourceHandler)).block(); } /** @@ -152,19 +145,20 @@ public void removeResource(String resourceUri) { /** * Add a new prompt handler. * @param promptRegistration The prompt registration to add + * @deprecated This method will be removed in 0.9.0. Use + * {@link #addPrompt(McpServerFeatures.SyncPromptSpecification)}. */ + @Deprecated public void addPrompt(McpServerFeatures.SyncPromptRegistration promptRegistration) { this.asyncServer.addPrompt(McpServerFeatures.AsyncPromptRegistration.fromSync(promptRegistration)).block(); } /** * Add a new prompt handler. - * @param promptRegistration The prompt registration to add - * @deprecated Use {@link #addPrompt(McpServerFeatures.SyncPromptRegistration)}. + * @param promptSpecification The prompt specification to add */ - @Deprecated - public void addPrompt(PromptRegistration promptRegistration) { - this.asyncServer.addPrompt(promptRegistration).block(); + public void addPrompt(McpServerFeatures.SyncPromptSpecification promptSpecification) { + this.asyncServer.addPrompt(McpServerFeatures.AsyncPromptSpecification.fromSync(promptSpecification)).block(); } /** @@ -201,7 +195,10 @@ public McpSchema.Implementation getServerInfo() { /** * Get the client capabilities that define the supported features and functionality. * @return The client capabilities + * @deprecated This method will be removed in 0.9.0. Use + * {@link McpSyncServerExchange#getClientCapabilities()}. */ + @Deprecated public ClientCapabilities getClientCapabilities() { return this.asyncServer.getClientCapabilities(); } @@ -209,7 +206,10 @@ public ClientCapabilities getClientCapabilities() { /** * Get the client implementation information. * @return The client implementation details + * @deprecated This method will be removed in 0.9.0. Use + * {@link McpSyncServerExchange#getClientInfo()}. */ + @Deprecated public McpSchema.Implementation getClientInfo() { return this.asyncServer.getClientInfo(); } @@ -282,7 +282,10 @@ public McpAsyncServer getAsyncServer() { * @see Sampling * Specification + * @deprecated This method will be removed in 0.9.0. Use + * {@link McpSyncServerExchange#createMessage(McpSchema.CreateMessageRequest)}. */ + @Deprecated public McpSchema.CreateMessageResult createMessage(McpSchema.CreateMessageRequest createMessageRequest) { return this.asyncServer.createMessage(createMessageRequest).block(); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java new file mode 100644 index 000000000..f121db552 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java @@ -0,0 +1,78 @@ +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.spec.McpSchema; + +/** + * Represents a synchronous exchange with a Model Context Protocol (MCP) client. The + * exchange provides methods to interact with the client and query its capabilities. + * + * @author Dariusz Jędrzejczyk + */ +public class McpSyncServerExchange { + + private final McpAsyncServerExchange exchange; + + /** + * Create a new synchronous exchange with the client using the provided asynchronous + * implementation as a delegate. + * @param exchange The asynchronous exchange to delegate to. + */ + public McpSyncServerExchange(McpAsyncServerExchange exchange) { + this.exchange = exchange; + } + + /** + * Get the client capabilities that define the supported features and functionality. + * @return The client capabilities + */ + public McpSchema.ClientCapabilities getClientCapabilities() { + return this.exchange.getClientCapabilities(); + } + + /** + * Get the client implementation information. + * @return The client implementation details + */ + public McpSchema.Implementation getClientInfo() { + return this.exchange.getClientInfo(); + } + + /** + * Create a new message using the sampling capabilities of the client. The Model + * Context Protocol (MCP) provides a standardized way for servers to request LLM + * sampling (“completions” or “generations”) from language models via clients. This + * flow allows clients to maintain control over model access, selection, and + * permissions while enabling servers to leverage AI capabilities—with no server API + * keys necessary. Servers can request text or image-based interactions and optionally + * include context from MCP servers in their prompts. + * @param createMessageRequest The request to create a new message + * @return A result containing the details of the sampling response + * @see McpSchema.CreateMessageRequest + * @see McpSchema.CreateMessageResult + * @see Sampling + * Specification + */ + public McpSchema.CreateMessageResult createMessage(McpSchema.CreateMessageRequest createMessageRequest) { + return this.exchange.createMessage(createMessageRequest).block(); + } + + /** + * Retrieves the list of all roots provided by the client. + * @return The list of roots result. + */ + public McpSchema.ListRootsResult listRoots() { + return this.exchange.listRoots().block(); + } + + /** + * Retrieves a paginated list of roots provided by the client. + * @param cursor Optional pagination cursor from a previous list request + * @return The list of roots result + */ + public McpSchema.ListRootsResult listRoots(String cursor) { + return this.exchange.listRoots(cursor).block(); + } + +} diff --git a/mcp/src/main/java/org/springframework/ai/mcp/server/transport/HttpServletSseServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransport.java similarity index 94% rename from mcp/src/main/java/org/springframework/ai/mcp/server/transport/HttpServletSseServerTransport.java rename to mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransport.java index d5ae27bd2..fa5dcf1c1 100644 --- a/mcp/src/main/java/org/springframework/ai/mcp/server/transport/HttpServletSseServerTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransport.java @@ -1,19 +1,7 @@ /* * Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. */ -package org.springframework.ai.mcp.server.transport; +package io.modelcontextprotocol.server.transport; import java.io.BufferedReader; import java.io.IOException; @@ -26,6 +14,9 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ServerMcpTransport; import jakarta.servlet.AsyncContext; import jakarta.servlet.ServletException; import jakarta.servlet.annotation.WebServlet; @@ -36,15 +27,14 @@ import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; -import org.springframework.ai.mcp.spec.McpError; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.ServerMcpTransport; - /** * A Servlet-based implementation of the MCP HTTP with Server-Sent Events (SSE) transport * specification. This implementation provides similar functionality to * WebFluxSseServerTransport but uses the traditional Servlet API instead of WebFlux. * + * @deprecated This class will be removed in 0.9.0. Use + * {@link HttpServletSseServerTransportProvider}. + * *

    * The transport handles two types of endpoints: *

      @@ -61,7 +51,6 @@ *
    • Graceful shutdown support
    • *
    • Error handling and response formatting
    • *
    - * * @author Christian Tzolov * @author Alexandros Pappas * @see ServerMcpTransport @@ -69,6 +58,7 @@ */ @WebServlet(asyncSupported = true) +@Deprecated public class HttpServletSseServerTransport extends HttpServlet implements ServerMcpTransport { /** Logger for this class */ diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java new file mode 100644 index 000000000..152462b1d --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -0,0 +1,432 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server.transport; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.PrintWriter; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.ServletException; +import jakarta.servlet.annotation.WebServlet; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * A Servlet-based implementation of the MCP HTTP with Server-Sent Events (SSE) transport + * specification. This implementation provides similar functionality to + * WebFluxSseServerTransportProvider but uses the traditional Servlet API instead of + * WebFlux. + * + *

    + * The transport handles two types of endpoints: + *

      + *
    • SSE endpoint (/sse) - Establishes a long-lived connection for server-to-client + * events
    • + *
    • Message endpoint (configurable) - Handles client-to-server message requests
    • + *
    + * + *

    + * Features: + *

      + *
    • Asynchronous message handling using Servlet 6.0 async support
    • + *
    • Session management for multiple client connections
    • + *
    • Graceful shutdown support
    • + *
    • Error handling and response formatting
    • + *
    + * + * @author Christian Tzolov + * @author Alexandros Pappas + * @see McpServerTransportProvider + * @see HttpServlet + */ + +@WebServlet(asyncSupported = true) +public class HttpServletSseServerTransportProvider extends HttpServlet implements McpServerTransportProvider { + + /** Logger for this class */ + private static final Logger logger = LoggerFactory.getLogger(HttpServletSseServerTransportProvider.class); + + public static final String UTF_8 = "UTF-8"; + + public static final String APPLICATION_JSON = "application/json"; + + public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}"; + + /** Default endpoint path for SSE connections */ + public static final String DEFAULT_SSE_ENDPOINT = "/sse"; + + /** Event type for regular messages */ + public static final String MESSAGE_EVENT_TYPE = "message"; + + /** Event type for endpoint information */ + public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + /** JSON object mapper for serialization/deserialization */ + private final ObjectMapper objectMapper; + + /** The endpoint path for handling client messages */ + private final String messageEndpoint; + + /** The endpoint path for handling SSE connections */ + private final String sseEndpoint; + + /** Map of active client sessions, keyed by session ID */ + private final Map sessions = new ConcurrentHashMap<>(); + + /** Flag indicating if the transport is in the process of shutting down */ + private final AtomicBoolean isClosing = new AtomicBoolean(false); + + /** Session factory for creating new sessions */ + private McpServerSession.Factory sessionFactory; + + /** + * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE + * endpoint. + * @param objectMapper The JSON object mapper to use for message + * serialization/deserialization + * @param messageEndpoint The endpoint path where clients will send their messages + * @param sseEndpoint The endpoint path where clients will establish SSE connections + */ + public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, + String sseEndpoint) { + this.objectMapper = objectMapper; + this.messageEndpoint = messageEndpoint; + this.sseEndpoint = sseEndpoint; + } + + /** + * Creates a new HttpServletSseServerTransportProvider instance with the default SSE + * endpoint. + * @param objectMapper The JSON object mapper to use for message + * serialization/deserialization + * @param messageEndpoint The endpoint path where clients will send their messages + */ + public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { + this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); + } + + /** + * Sets the session factory for creating new sessions. + * @param sessionFactory The session factory to use + */ + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + /** + * Broadcasts a notification to all connected clients. + * @param method The method name for the notification + * @param params The parameters for the notification + * @return A Mono that completes when the broadcast attempt is finished + */ + @Override + public Mono notifyClients(String method, Map params) { + if (sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + + return Flux.fromIterable(sessions.values()) + .flatMap(session -> session.sendNotification(method, params) + .doOnError( + e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage())) + .onErrorComplete()) + .then(); + } + + /** + * Handles GET requests to establish SSE connections. + *

    + * This method sets up a new SSE connection when a client connects to the SSE + * endpoint. It configures the response headers for SSE, creates a new session, and + * sends the initial endpoint information to the client. + * @param request The HTTP servlet request + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String pathInfo = request.getPathInfo(); + if (!sseEndpoint.equals(pathInfo)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + if (isClosing.get()) { + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return; + } + + response.setContentType("text/event-stream"); + response.setCharacterEncoding(UTF_8); + response.setHeader("Cache-Control", "no-cache"); + response.setHeader("Connection", "keep-alive"); + response.setHeader("Access-Control-Allow-Origin", "*"); + + String sessionId = UUID.randomUUID().toString(); + AsyncContext asyncContext = request.startAsync(); + asyncContext.setTimeout(0); + + PrintWriter writer = response.getWriter(); + + // Create a new session transport + HttpServletMcpSessionTransport sessionTransport = new HttpServletMcpSessionTransport(sessionId, asyncContext, + writer); + + // Create a new session using the session factory + McpServerSession session = sessionFactory.create(sessionTransport); + this.sessions.put(sessionId, session); + + // Send initial endpoint event + this.sendEvent(writer, ENDPOINT_EVENT_TYPE, messageEndpoint + "?sessionId=" + sessionId); + } + + /** + * Handles POST requests for client messages. + *

    + * This method processes incoming messages from clients, routes them through the + * session handler, and sends back the appropriate response. It handles error cases + * and formats error responses according to the MCP specification. + * @param request The HTTP servlet request + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doPost(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + if (isClosing.get()) { + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return; + } + + String pathInfo = request.getPathInfo(); + if (!messageEndpoint.equals(pathInfo)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + // Get the session ID from the request parameter + String sessionId = request.getParameter("sessionId"); + if (sessionId == null) { + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(HttpServletResponse.SC_BAD_REQUEST); + String jsonError = objectMapper.writeValueAsString(new McpError("Session ID missing in message endpoint")); + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + return; + } + + // Get the session from the sessions map + McpServerSession session = sessions.get(sessionId); + if (session == null) { + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(HttpServletResponse.SC_NOT_FOUND); + String jsonError = objectMapper.writeValueAsString(new McpError("Session not found: " + sessionId)); + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + return; + } + + try { + BufferedReader reader = request.getReader(); + StringBuilder body = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + body.append(line); + } + + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString()); + + // Process the message through the session's handle method + session.handle(message).block(); // Block for Servlet compatibility + + response.setStatus(HttpServletResponse.SC_OK); + } + catch (Exception e) { + logger.error("Error processing message: {}", e.getMessage()); + try { + McpError mcpError = new McpError(e.getMessage()); + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); + String jsonError = objectMapper.writeValueAsString(mcpError); + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + } + catch (IOException ex) { + logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage()); + response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Error processing message"); + } + } + } + + /** + * Initiates a graceful shutdown of the transport. + *

    + * This method marks the transport as closing and closes all active client sessions. + * New connection attempts will be rejected during shutdown. + * @return A Mono that completes when all sessions have been closed + */ + @Override + public Mono closeGracefully() { + isClosing.set(true); + logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); + + return Flux.fromIterable(sessions.values()).flatMap(McpServerSession::closeGracefully).then(); + } + + /** + * Sends an SSE event to a client. + * @param writer The writer to send the event through + * @param eventType The type of event (message or endpoint) + * @param data The event data + * @throws IOException If an error occurs while writing the event + */ + private void sendEvent(PrintWriter writer, String eventType, String data) throws IOException { + writer.write("event: " + eventType + "\n"); + writer.write("data: " + data + "\n\n"); + writer.flush(); + + if (writer.checkError()) { + throw new IOException("Client disconnected"); + } + } + + /** + * Cleans up resources when the servlet is being destroyed. + *

    + * This method ensures a graceful shutdown by closing all client connections before + * calling the parent's destroy method. + */ + @Override + public void destroy() { + closeGracefully().block(); + super.destroy(); + } + + /** + * Implementation of McpServerTransport for HttpServlet SSE sessions. This class + * handles the transport-level communication for a specific client session. + */ + private class HttpServletMcpSessionTransport implements McpServerTransport { + + private final String sessionId; + + private final AsyncContext asyncContext; + + private final PrintWriter writer; + + /** + * Creates a new session transport with the specified ID and SSE writer. + * @param sessionId The unique identifier for this session + * @param asyncContext The async context for the session + * @param writer The writer for sending server events to the client + */ + HttpServletMcpSessionTransport(String sessionId, AsyncContext asyncContext, PrintWriter writer) { + this.sessionId = sessionId; + this.asyncContext = asyncContext; + this.writer = writer; + logger.debug("Session transport {} initialized with SSE writer", sessionId); + } + + /** + * Sends a JSON-RPC message to the client through the SSE connection. + * @param message The JSON-RPC message to send + * @return A Mono that completes when the message has been sent + */ + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return Mono.fromRunnable(() -> { + try { + String jsonText = objectMapper.writeValueAsString(message); + sendEvent(writer, MESSAGE_EVENT_TYPE, jsonText); + logger.debug("Message sent to session {}", sessionId); + } + catch (Exception e) { + logger.error("Failed to send message to session {}: {}", sessionId, e.getMessage()); + sessions.remove(sessionId); + asyncContext.complete(); + } + }); + } + + /** + * Converts data from one type to another using the configured ObjectMapper. + * @param data The source data object to convert + * @param typeRef The target type reference + * @return The converted object of type T + * @param The target type + */ + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + /** + * Initiates a graceful shutdown of the transport. + * @return A Mono that completes when the shutdown is complete + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + logger.debug("Closing session transport: {}", sessionId); + try { + sessions.remove(sessionId); + asyncContext.complete(); + logger.debug("Successfully completed async context for session {}", sessionId); + } + catch (Exception e) { + logger.warn("Failed to complete async context for session {}: {}", sessionId, e.getMessage()); + } + }); + } + + /** + * Closes the transport immediately. + */ + @Override + public void close() { + try { + sessions.remove(sessionId); + asyncContext.complete(); + logger.debug("Successfully completed async context for session {}", sessionId); + } + catch (Exception e) { + logger.warn("Failed to complete async context for session {}: {}", sessionId, e.getMessage()); + } + } + + } + +} diff --git a/mcp/src/main/java/org/springframework/ai/mcp/server/transport/StdioServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransport.java similarity index 82% rename from mcp/src/main/java/org/springframework/ai/mcp/server/transport/StdioServerTransport.java rename to mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransport.java index e9b88d247..78264ca32 100644 --- a/mcp/src/main/java/org/springframework/ai/mcp/server/transport/StdioServerTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransport.java @@ -1,20 +1,8 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.server.transport; +package io.modelcontextprotocol.server.transport; import java.io.BufferedReader; import java.io.IOException; @@ -22,12 +10,15 @@ import java.io.InputStreamReader; import java.io.OutputStream; import java.nio.charset.StandardCharsets; -import java.time.Duration; import java.util.concurrent.Executors; import java.util.function.Function; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; @@ -36,18 +27,16 @@ import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.JSONRPCMessage; -import org.springframework.ai.mcp.spec.ServerMcpTransport; -import org.springframework.ai.mcp.util.Assert; - /** * Implementation of the MCP Stdio transport for servers that communicates using standard * input/output streams. Messages are exchanged as newline-delimited JSON-RPC messages * over stdin/stdout, with errors and debug information sent to stderr. * * @author Christian Tzolov + * @deprecated This method will be removed in 0.9.0. Use + * {@link io.modelcontextprotocol.server.transport.StdioServerTransportProvider} instead. */ +@Deprecated public class StdioServerTransport implements ServerMcpTransport { private static final Logger logger = LoggerFactory.getLogger(StdioServerTransport.class); @@ -118,10 +107,10 @@ private void handleIncomingMessages(Function, Mono Mono.just(message) .transform(inboundMessageHandler) .contextWrite(ctx -> ctx.put("observation", "myObservation"))) - .doOnComplete(() -> { + .doOnTerminate(() -> { + // The outbound processing will dispose its scheduler upon completion this.outboundSink.tryEmitComplete(); this.inboundScheduler.dispose(); - this.outboundScheduler.dispose(); }) .subscribe(); } @@ -155,6 +144,8 @@ private void startInboundProcessing() { break; } + logger.debug("Received JSON message: {}", line); + try { JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, line); if (!this.inboundSink.tryEmitNext(message).isSuccess()) { @@ -221,13 +212,13 @@ else if (isClosing) { }) .doOnComplete(() -> { isClosing = true; - outboundSink.tryEmitComplete(); + outboundScheduler.dispose(); }) .doOnError(e -> { if (!isClosing) { logger.error("Error in outbound processing", e); isClosing = true; - outboundSink.tryEmitComplete(); + outboundScheduler.dispose(); } }) .map(msg -> (JSONRPCMessage) msg); @@ -237,26 +228,15 @@ else if (isClosing) { @Override public Mono closeGracefully() { - - return Mono.fromRunnable(() -> { + return Mono.defer(() -> { isClosing = true; logger.debug("Initiating graceful shutdown"); - }).then(Mono.defer(() -> { - // First complete the sinks to stop processing + // Completing the inbound causes the outbound to be completed as well, so + // we only close the inbound. inboundSink.tryEmitComplete(); - outboundSink.tryEmitComplete(); - return Mono.delay(Duration.ofMillis(100)); - })).then(Mono.fromRunnable(() -> { - try { - // Dispose schedulers with longer timeout - inboundScheduler.dispose(); - outboundScheduler.dispose(); - logger.info("Graceful shutdown completed"); - } - catch (Exception e) { - logger.error("Error during graceful shutdown", e); - } - })).then().subscribeOn(Schedulers.boundedElastic()); + logger.debug("Graceful shutdown complete"); + return Mono.empty(); + }).subscribeOn(Schedulers.boundedElastic()); } @Override diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java new file mode 100644 index 000000000..6a7d29039 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java @@ -0,0 +1,306 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.io.Reader; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +/** + * Implementation of the MCP Stdio transport provider for servers that communicates using + * standard input/output streams. Messages are exchanged as newline-delimited JSON-RPC + * messages over stdin/stdout, with errors and debug information sent to stderr. + * + * @author Christian Tzolov + */ +public class StdioServerTransportProvider implements McpServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(StdioServerTransportProvider.class); + + private final ObjectMapper objectMapper; + + private final InputStream inputStream; + + private final OutputStream outputStream; + + private McpServerSession session; + + private final AtomicBoolean isClosing = new AtomicBoolean(false); + + private final Sinks.One inboundReady = Sinks.one(); + + /** + * Creates a new StdioServerTransportProvider with a default ObjectMapper and System + * streams. + */ + public StdioServerTransportProvider() { + this(new ObjectMapper()); + } + + /** + * Creates a new StdioServerTransportProvider with the specified ObjectMapper and + * System streams. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + */ + public StdioServerTransportProvider(ObjectMapper objectMapper) { + this(objectMapper, System.in, System.out); + } + + /** + * Creates a new StdioServerTransportProvider with the specified ObjectMapper and + * streams. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * @param inputStream The input stream to read from + * @param outputStream The output stream to write to + */ + public StdioServerTransportProvider(ObjectMapper objectMapper, InputStream inputStream, OutputStream outputStream) { + Assert.notNull(objectMapper, "The ObjectMapper can not be null"); + Assert.notNull(inputStream, "The InputStream can not be null"); + Assert.notNull(outputStream, "The OutputStream can not be null"); + + this.objectMapper = objectMapper; + this.inputStream = inputStream; + this.outputStream = outputStream; + } + + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + // Create a single session for the stdio connection + this.session = sessionFactory.create(new StdioMcpSessionTransport()); + } + + @Override + public Mono notifyClients(String method, Map params) { + if (this.session == null) { + return Mono.error(new McpError("No session to close")); + } + return this.session.sendNotification(method, params) + .doOnError(e -> logger.error("Failed to send notification: {}", e.getMessage())); + } + + @Override + public Mono closeGracefully() { + if (this.session == null) { + return Mono.empty(); + } + return this.session.closeGracefully(); + } + + /** + * Implementation of McpServerTransport for the stdio session. + */ + private class StdioMcpSessionTransport implements McpServerTransport { + + private final Sinks.Many inboundSink; + + private final Sinks.Many outboundSink; + + private final AtomicBoolean isStarted = new AtomicBoolean(false); + + /** Scheduler for handling inbound messages */ + private Scheduler inboundScheduler; + + /** Scheduler for handling outbound messages */ + private Scheduler outboundScheduler; + + private final Sinks.One outboundReady = Sinks.one(); + + public StdioMcpSessionTransport() { + + this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); + this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); + + // Use bounded schedulers for better resource management + this.inboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), + "stdio-inbound"); + this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), + "stdio-outbound"); + + handleIncomingMessages(); + startInboundProcessing(); + startOutboundProcessing(); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + + return Mono.zip(inboundReady.asMono(), outboundReady.asMono()).then(Mono.defer(() -> { + if (outboundSink.tryEmitNext(message).isSuccess()) { + return Mono.empty(); + } + else { + return Mono.error(new RuntimeException("Failed to enqueue message")); + } + })); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + isClosing.set(true); + logger.debug("Session transport closing gracefully"); + inboundSink.tryEmitComplete(); + }); + } + + @Override + public void close() { + isClosing.set(true); + logger.debug("Session transport closed"); + } + + private void handleIncomingMessages() { + this.inboundSink.asFlux().flatMap(message -> session.handle(message)).doOnTerminate(() -> { + // The outbound processing will dispose its scheduler upon completion + this.outboundSink.tryEmitComplete(); + this.inboundScheduler.dispose(); + }).subscribe(); + } + + /** + * Starts the inbound processing thread that reads JSON-RPC messages from stdin. + * Messages are deserialized and passed to the session for handling. + */ + private void startInboundProcessing() { + if (isStarted.compareAndSet(false, true)) { + this.inboundScheduler.schedule(() -> { + inboundReady.tryEmitValue(null); + BufferedReader reader = null; + try { + reader = new BufferedReader(new InputStreamReader(inputStream)); + while (!isClosing.get()) { + try { + String line = reader.readLine(); + if (line == null || isClosing.get()) { + break; + } + + logger.debug("Received JSON message: {}", line); + + try { + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, + line); + if (!this.inboundSink.tryEmitNext(message).isSuccess()) { + // logIfNotClosing("Failed to enqueue message"); + break; + } + + } + catch (Exception e) { + logIfNotClosing("Error processing inbound message", e); + break; + } + } + catch (IOException e) { + logIfNotClosing("Error reading from stdin", e); + break; + } + } + } + catch (Exception e) { + logIfNotClosing("Error in inbound processing", e); + } + finally { + isClosing.set(true); + if (session != null) { + session.close(); + } + inboundSink.tryEmitComplete(); + } + }); + } + } + + /** + * Starts the outbound processing thread that writes JSON-RPC messages to stdout. + * Messages are serialized to JSON and written with a newline delimiter. + */ + private void startOutboundProcessing() { + Function, Flux> outboundConsumer = messages -> messages // @formatter:off + .doOnSubscribe(subscription -> outboundReady.tryEmitValue(null)) + .publishOn(outboundScheduler) + .handle((message, sink) -> { + if (message != null && !isClosing.get()) { + try { + String jsonMessage = objectMapper.writeValueAsString(message); + // Escape any embedded newlines in the JSON message as per spec + jsonMessage = jsonMessage.replace("\r\n", "\\n").replace("\n", "\\n").replace("\r", "\\n"); + + synchronized (outputStream) { + outputStream.write(jsonMessage.getBytes(StandardCharsets.UTF_8)); + outputStream.write("\n".getBytes(StandardCharsets.UTF_8)); + outputStream.flush(); + } + sink.next(message); + } + catch (IOException e) { + if (!isClosing.get()) { + logger.error("Error writing message", e); + sink.error(new RuntimeException(e)); + } + else { + logger.debug("Stream closed during shutdown", e); + } + } + } + else if (isClosing.get()) { + sink.complete(); + } + }) + .doOnComplete(() -> { + isClosing.set(true); + outboundScheduler.dispose(); + }) + .doOnError(e -> { + if (!isClosing.get()) { + logger.error("Error in outbound processing", e); + isClosing.set(true); + outboundScheduler.dispose(); + } + }) + .map(msg -> (JSONRPCMessage) msg); + + outboundConsumer.apply(outboundSink.asFlux()).subscribe(); + } // @formatter:on + + private void logIfNotClosing(String message, Exception e) { + if (!isClosing.get()) { + logger.error(message, e); + } + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java new file mode 100644 index 000000000..8464b6ae7 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java @@ -0,0 +1,15 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +*/ +package io.modelcontextprotocol.spec; + +/** + * Marker interface for the client-side MCP transport. + * + * @author Christian Tzolov + * @deprecated This class will be removed in 0.9.0. Use {@link McpClientTransport}. + */ +@Deprecated +public interface ClientMcpTransport extends McpTransport { + +} diff --git a/mcp/src/main/java/org/springframework/ai/mcp/spec/DefaultMcpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java similarity index 92% rename from mcp/src/main/java/org/springframework/ai/mcp/spec/DefaultMcpSession.java rename to mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java index fffd710b2..83de4c094 100644 --- a/mcp/src/main/java/org/springframework/ai/mcp/spec/DefaultMcpSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java @@ -1,20 +1,8 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.spec; +package io.modelcontextprotocol.spec; import java.time.Duration; import java.util.Map; @@ -23,14 +11,13 @@ import java.util.concurrent.atomic.AtomicLong; import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.Disposable; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoSink; -import org.springframework.ai.mcp.util.Assert; - /** * Default implementation of the MCP (Model Context Protocol) session that manages * bidirectional JSON-RPC communication between clients and servers. This implementation @@ -47,7 +34,10 @@ * * @author Christian Tzolov * @author Dariusz Jędrzejczyk + * @deprecated This method will be removed in 0.9.0. Use {@link McpClientSession} instead */ +@Deprecated + public class DefaultMcpSession implements McpSession { /** Logger for this class */ @@ -137,7 +127,7 @@ public DefaultMcpSession(Duration requestTimeout, McpTransport transport, // consumer this.connection = this.transport.connect(mono -> mono.doOnNext(message -> { if (message instanceof McpSchema.JSONRPCResponse response) { - logger.info("Received Response: {}", response); + logger.debug("Received Response: {}", response); var sink = pendingResponses.remove(response.id()); if (sink == null) { logger.warn("Unexpected response for unkown id {}", response.id()); @@ -147,7 +137,7 @@ public DefaultMcpSession(Duration requestTimeout, McpTransport transport, } } else if (message instanceof McpSchema.JSONRPCRequest request) { - logger.info("Received request: {}", request); + logger.debug("Received request: {}", request); handleIncomingRequest(request).subscribe(response -> transport.sendMessage(response).subscribe(), error -> { var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), @@ -157,7 +147,7 @@ else if (message instanceof McpSchema.JSONRPCRequest request) { }); } else if (message instanceof McpSchema.JSONRPCNotification notification) { - logger.info("Received notification: {}", notification); + logger.debug("Received notification: {}", notification); handleIncomingNotification(notification).subscribe(null, error -> logger.error("Error handling notification: {}", error.getMessage())); } @@ -283,8 +273,10 @@ public Mono sendNotification(String method, Map params) { */ @Override public Mono closeGracefully() { - this.connection.dispose(); - return transport.closeGracefully(); + return Mono.defer(() -> { + this.connection.dispose(); + return transport.closeGracefully(); + }); } /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java new file mode 100644 index 000000000..6657e3622 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -0,0 +1,288 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.time.Duration; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoSink; + +/** + * Default implementation of the MCP (Model Context Protocol) session that manages + * bidirectional JSON-RPC communication between clients and servers. This implementation + * follows the MCP specification for message exchange and transport handling. + * + *

    + * The session manages: + *

      + *
    • Request/response handling with unique message IDs
    • + *
    • Notification processing
    • + *
    • Message timeout management
    • + *
    • Transport layer abstraction
    • + *
    + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +public class McpClientSession implements McpSession { + + /** Logger for this class */ + private static final Logger logger = LoggerFactory.getLogger(McpClientSession.class); + + /** Duration to wait for request responses before timing out */ + private final Duration requestTimeout; + + /** Transport layer implementation for message exchange */ + private final McpTransport transport; + + /** Map of pending responses keyed by request ID */ + private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); + + /** Map of request handlers keyed by method name */ + private final ConcurrentHashMap> requestHandlers = new ConcurrentHashMap<>(); + + /** Map of notification handlers keyed by method name */ + private final ConcurrentHashMap notificationHandlers = new ConcurrentHashMap<>(); + + /** Session-specific prefix for request IDs */ + private final String sessionPrefix = UUID.randomUUID().toString().substring(0, 8); + + /** Atomic counter for generating unique request IDs */ + private final AtomicLong requestCounter = new AtomicLong(0); + + private final Disposable connection; + + /** + * Functional interface for handling incoming JSON-RPC requests. Implementations + * should process the request parameters and return a response. + * + * @param Response type + */ + @FunctionalInterface + public interface RequestHandler { + + /** + * Handles an incoming request with the given parameters. + * @param params The request parameters + * @return A Mono containing the response object + */ + Mono handle(Object params); + + } + + /** + * Functional interface for handling incoming JSON-RPC notifications. Implementations + * should process the notification parameters without returning a response. + */ + @FunctionalInterface + public interface NotificationHandler { + + /** + * Handles an incoming notification with the given parameters. + * @param params The notification parameters + * @return A Mono that completes when the notification is processed + */ + Mono handle(Object params); + + } + + /** + * Creates a new McpClientSession with the specified configuration and handlers. + * @param requestTimeout Duration to wait for responses + * @param transport Transport implementation for message exchange + * @param requestHandlers Map of method names to request handlers + * @param notificationHandlers Map of method names to notification handlers + */ + public McpClientSession(Duration requestTimeout, McpTransport transport, + Map> requestHandlers, Map notificationHandlers) { + + Assert.notNull(requestTimeout, "The requstTimeout can not be null"); + Assert.notNull(transport, "The transport can not be null"); + Assert.notNull(requestHandlers, "The requestHandlers can not be null"); + Assert.notNull(notificationHandlers, "The notificationHandlers can not be null"); + + this.requestTimeout = requestTimeout; + this.transport = transport; + this.requestHandlers.putAll(requestHandlers); + this.notificationHandlers.putAll(notificationHandlers); + + // TODO: consider mono.transformDeferredContextual where the Context contains + // the + // Observation associated with the individual message - it can be used to + // create child Observation and emit it together with the message to the + // consumer + this.connection = this.transport.connect(mono -> mono.doOnNext(message -> { + if (message instanceof McpSchema.JSONRPCResponse response) { + logger.debug("Received Response: {}", response); + var sink = pendingResponses.remove(response.id()); + if (sink == null) { + logger.warn("Unexpected response for unkown id {}", response.id()); + } + else { + sink.success(response); + } + } + else if (message instanceof McpSchema.JSONRPCRequest request) { + logger.debug("Received request: {}", request); + handleIncomingRequest(request).subscribe(response -> transport.sendMessage(response).subscribe(), + error -> { + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), + null, new McpSchema.JSONRPCResponse.JSONRPCError( + McpSchema.ErrorCodes.INTERNAL_ERROR, error.getMessage(), null)); + transport.sendMessage(errorResponse).subscribe(); + }); + } + else if (message instanceof McpSchema.JSONRPCNotification notification) { + logger.debug("Received notification: {}", notification); + handleIncomingNotification(notification).subscribe(null, + error -> logger.error("Error handling notification: {}", error.getMessage())); + } + })).subscribe(); + } + + /** + * Handles an incoming JSON-RPC request by routing it to the appropriate handler. + * @param request The incoming JSON-RPC request + * @return A Mono containing the JSON-RPC response + */ + private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request) { + return Mono.defer(() -> { + var handler = this.requestHandlers.get(request.method()); + if (handler == null) { + MethodNotFoundError error = getMethodNotFoundError(request.method()); + return Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, + error.message(), error.data()))); + } + + return handler.handle(request.params()) + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) + .onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), + null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), null)))); // TODO: add error message + // through the data field + }); + } + + record MethodNotFoundError(String method, String message, Object data) { + } + + public static MethodNotFoundError getMethodNotFoundError(String method) { + switch (method) { + case McpSchema.METHOD_ROOTS_LIST: + return new MethodNotFoundError(method, "Roots not supported", + Map.of("reason", "Client does not have roots capability")); + default: + return new MethodNotFoundError(method, "Method not found: " + method, null); + } + } + + /** + * Handles an incoming JSON-RPC notification by routing it to the appropriate handler. + * @param notification The incoming JSON-RPC notification + * @return A Mono that completes when the notification is processed + */ + private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification) { + return Mono.defer(() -> { + var handler = notificationHandlers.get(notification.method()); + if (handler == null) { + logger.error("No handler registered for notification method: {}", notification.method()); + return Mono.empty(); + } + return handler.handle(notification.params()); + }); + } + + /** + * Generates a unique request ID in a non-blocking way. Combines a session-specific + * prefix with an atomic counter to ensure uniqueness. + * @return A unique request ID string + */ + private String generateRequestId() { + return this.sessionPrefix + "-" + this.requestCounter.getAndIncrement(); + } + + /** + * Sends a JSON-RPC request and returns the response. + * @param The expected response type + * @param method The method name to call + * @param requestParams The request parameters + * @param typeRef Type reference for response deserialization + * @return A Mono containing the response + */ + @Override + public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + String requestId = this.generateRequestId(); + + return Mono.create(sink -> { + this.pendingResponses.put(requestId, sink); + McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, + requestId, requestParams); + this.transport.sendMessage(jsonrpcRequest) + // TODO: It's most efficient to create a dedicated Subscriber here + .subscribe(v -> { + }, error -> { + this.pendingResponses.remove(requestId); + sink.error(error); + }); + }).timeout(this.requestTimeout).handle((jsonRpcResponse, sink) -> { + if (jsonRpcResponse.error() != null) { + sink.error(new McpError(jsonRpcResponse.error())); + } + else { + if (typeRef.getType().equals(Void.class)) { + sink.complete(); + } + else { + sink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef)); + } + } + }); + } + + /** + * Sends a JSON-RPC notification. + * @param method The method name for the notification + * @param params The notification parameters + * @return A Mono that completes when the notification is sent + */ + @Override + public Mono sendNotification(String method, Map params) { + McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + method, params); + return this.transport.sendMessage(jsonrpcNotification); + } + + /** + * Closes the session gracefully, allowing pending operations to complete. + * @return A Mono that completes when the session is closed + */ + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + this.connection.dispose(); + return transport.closeGracefully(); + }); + } + + /** + * Closes the session immediately, potentially interrupting pending operations. + */ + @Override + public void close() { + this.connection.dispose(); + transport.close(); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java new file mode 100644 index 000000000..458979651 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java @@ -0,0 +1,21 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +*/ +package io.modelcontextprotocol.spec; + +import java.util.function.Function; + +import reactor.core.publisher.Mono; + +/** + * Marker interface for the client-side MCP transport. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +public interface McpClientTransport extends ClientMcpTransport { + + @Override + Mono connect(Function, Mono> handler); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java new file mode 100644 index 000000000..13e43240b --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java @@ -0,0 +1,25 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +*/ +package io.modelcontextprotocol.spec; + +import io.modelcontextprotocol.spec.McpSchema.JSONRPCResponse.JSONRPCError; + +public class McpError extends RuntimeException { + + private JSONRPCError jsonRpcError; + + public McpError(JSONRPCError jsonRpcError) { + super(jsonRpcError.message()); + this.jsonRpcError = jsonRpcError; + } + + public McpError(Object error) { + super(error.toString()); + } + + public JSONRPCError getJsonRpcError() { + return jsonRpcError; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/org/springframework/ai/mcp/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java similarity index 87% rename from mcp/src/main/java/org/springframework/ai/mcp/spec/McpSchema.java rename to mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index e107fb6bc..37d9e0c0a 100644 --- a/mcp/src/main/java/org/springframework/ai/mcp/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -1,22 +1,11 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.spec; +package io.modelcontextprotocol.spec; import java.io.IOException; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -29,6 +18,8 @@ import com.fasterxml.jackson.annotation.JsonTypeInfo.As; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Based on the JSON-RPC 2.0 @@ -40,6 +31,8 @@ */ public final class McpSchema { + private static final Logger logger = LoggerFactory.getLogger(McpSchema.class); + private McpSchema() { } @@ -156,6 +149,8 @@ public sealed interface Request public static JSONRPCMessage deserializeJsonRpcMessage(ObjectMapper objectMapper, String jsonText) throws IOException { + logger.debug("Received JSON message: {}", jsonText); + var map = objectMapper.readValue(jsonText, MAP_TYPE_REF); // Determine message type based on specific JSON structure @@ -182,6 +177,7 @@ public sealed interface JSONRPCMessage permits JSONRPCRequest, JSONRPCNotificati } @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record JSONRPCRequest( // @formatter:off @JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("method") String method, @@ -190,6 +186,7 @@ public record JSONRPCRequest( // @formatter:off } // @formatter:on @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record JSONRPCNotification( // @formatter:off @JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("method") String method, @@ -197,6 +194,7 @@ public record JSONRPCNotification( // @formatter:off } // @formatter:on @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record JSONRPCResponse( // @formatter:off @JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, @@ -204,6 +202,7 @@ public record JSONRPCResponse( // @formatter:off @JsonProperty("error") JSONRPCError error) implements JSONRPCMessage { @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record JSONRPCError( @JsonProperty("code") int code, @JsonProperty("message") String message, @@ -215,6 +214,7 @@ public record JSONRPCError( // Initialization // --------------------------- @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record InitializeRequest( // @formatter:off @JsonProperty("protocolVersion") String protocolVersion, @JsonProperty("capabilities") ClientCapabilities capabilities, @@ -245,6 +245,7 @@ public record InitializeResult( // @formatter:off * */ @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record ClientCapabilities( // @formatter:off @JsonProperty("experimental") Map experimental, @JsonProperty("roots") RootCapabilities roots, @@ -259,7 +260,8 @@ public record ClientCapabilities( // @formatter:off * @param listChanged Whether the client would send notification about roots * has changed since the last time the server checked. */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record RootCapabilities( @JsonProperty("listChanged") Boolean listChanged) { } @@ -309,6 +311,7 @@ public ClientCapabilities build() { }// @formatter:on @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record ServerCapabilities( // @formatter:off @JsonProperty("experimental") Map experimental, @JsonProperty("logging") LoggingCapabilities logging, @@ -381,6 +384,7 @@ public ServerCapabilities build() { } // @formatter:on @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record Implementation(// @formatter:off @JsonProperty("name") String name, @JsonProperty("version") String version) { @@ -419,6 +423,7 @@ public interface Annotated { * optional. It is a number between 0 and 1. */ @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record Annotations( // @formatter:off @JsonProperty("audience") List audience, @JsonProperty("priority") Double priority) { @@ -464,6 +469,7 @@ public record Resource( // @formatter:off * @see RFC 6570 */ @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record ResourceTemplate( // @formatter:off @JsonProperty("uriTemplate") String uriTemplate, @JsonProperty("name") String name, @@ -487,6 +493,7 @@ public record ListResourceTemplatesResult( // @formatter:off } // @formatter:on @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record ReadResourceRequest( // @formatter:off @JsonProperty("uri") String uri){ } // @formatter:on @@ -505,11 +512,13 @@ public record ReadResourceResult( // @formatter:off * it is up to the server how to interpret it. */ @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record SubscribeRequest( // @formatter:off @JsonProperty("uri") String uri){ } // @formatter:on @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record UnsubscribeRequest( // @formatter:off @JsonProperty("uri") String uri){ } // @formatter:on @@ -580,6 +589,7 @@ public record BlobResourceContents( // @formatter:off * @param arguments A list of arguments to use for templating the prompt. */ @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record Prompt( // @formatter:off @JsonProperty("name") String name, @JsonProperty("description") String description, @@ -594,6 +604,7 @@ public record Prompt( // @formatter:off * @param required Whether this argument must be provided. */ @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record PromptArgument( // @formatter:off @JsonProperty("name") String name, @JsonProperty("description") String description, @@ -610,6 +621,7 @@ public record PromptArgument( // @formatter:off * @param content The content of the message of type {@link Content}. */ @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record PromptMessage( // @formatter:off @JsonProperty("role") Role role, @JsonProperty("content") Content content) { @@ -636,6 +648,7 @@ public record ListPromptsResult( // @formatter:off * @param arguments Arguments to use for templating the prompt. */ @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record GetPromptRequest(// @formatter:off @JsonProperty("name") String name, @JsonProperty("arguments") Map arguments) implements Request { @@ -673,9 +686,9 @@ public record ListToolsResult( // @formatter:off @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) - record JsonSchema( // @formatter:off - @JsonProperty("type") String type, - @JsonProperty("properties") Map properties, + public record JsonSchema( // @formatter:off + @JsonProperty("type") String type, + @JsonProperty("properties") Map properties, @JsonProperty("required") List required, @JsonProperty("additionalProperties") Boolean additionalProperties) { } // @formatter:on @@ -694,6 +707,7 @@ record JsonSchema( // @formatter:off * arguments before sending them to the server. */ @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record Tool( // @formatter:off @JsonProperty("name") String name, @JsonProperty("description") String description, @@ -748,18 +762,67 @@ public record CallToolResult( // @formatter:off // Sampling Interfaces // --------------------------- @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record ModelPreferences(// @formatter:off - @JsonProperty("hints") List hints, - @JsonProperty("costPriority") Double costPriority, - @JsonProperty("speedPriority") Double speedPriority, - @JsonProperty("intelligencePriority") Double intelligencePriority) { - } // @formatter:on + @JsonProperty("hints") List hints, + @JsonProperty("costPriority") Double costPriority, + @JsonProperty("speedPriority") Double speedPriority, + @JsonProperty("intelligencePriority") Double intelligencePriority) { + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private List hints; + private Double costPriority; + private Double speedPriority; + private Double intelligencePriority; + + public Builder hints(List hints) { + this.hints = hints; + return this; + } + + public Builder addHint(String name) { + if (this.hints == null) { + this.hints = new ArrayList<>(); + } + this.hints.add(new ModelHint(name)); + return this; + } + + public Builder costPriority(Double costPriority) { + this.costPriority = costPriority; + return this; + } + + public Builder speedPriority(Double speedPriority) { + this.speedPriority = speedPriority; + return this; + } + + public Builder intelligencePriority(Double intelligencePriority) { + this.intelligencePriority = intelligencePriority; + return this; + } + + public ModelPreferences build() { + return new ModelPreferences(hints, costPriority, speedPriority, intelligencePriority); + } + } +} // @formatter:on @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record ModelHint(@JsonProperty("name") String name) { + public static ModelHint of(String name) { + return new ModelHint(name); + } } @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record SamplingMessage(// @formatter:off @JsonProperty("role") Role role, @JsonProperty("content") Content content) { @@ -767,6 +830,7 @@ public record SamplingMessage(// @formatter:off // Sampling and Message Creation @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record CreateMessageRequest(// @formatter:off @JsonProperty("messages") List messages, @JsonProperty("modelPreferences") ModelPreferences modelPreferences, @@ -779,8 +843,68 @@ public record CreateMessageRequest(// @formatter:off public enum ContextInclusionStrategy { @JsonProperty("none") NONE, - @JsonProperty("this_server") THIS_SERVER, - @JsonProperty("all_server") ALL_SERVERS + @JsonProperty("thisServer") THIS_SERVER, + @JsonProperty("allServers") ALL_SERVERS + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private List messages; + private ModelPreferences modelPreferences; + private String systemPrompt; + private ContextInclusionStrategy includeContext; + private Double temperature; + private int maxTokens; + private List stopSequences; + private Map metadata; + + public Builder messages(List messages) { + this.messages = messages; + return this; + } + + public Builder modelPreferences(ModelPreferences modelPreferences) { + this.modelPreferences = modelPreferences; + return this; + } + + public Builder systemPrompt(String systemPrompt) { + this.systemPrompt = systemPrompt; + return this; + } + + public Builder includeContext(ContextInclusionStrategy includeContext) { + this.includeContext = includeContext; + return this; + } + + public Builder temperature(Double temperature) { + this.temperature = temperature; + return this; + } + + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + public Builder stopSequences(List stopSequences) { + this.stopSequences = stopSequences; + return this; + } + + public Builder metadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public CreateMessageRequest build() { + return new CreateMessageRequest(messages, modelPreferences, systemPrompt, + includeContext, temperature, maxTokens, stopSequences, metadata); + } } }// @formatter:on @@ -793,9 +917,9 @@ public record CreateMessageResult(// @formatter:off @JsonProperty("stopReason") StopReason stopReason) { public enum StopReason { - @JsonProperty("end_turn") END_TURN, - @JsonProperty("stop_sequence") STOP_SEQUENCE, - @JsonProperty("max_tokens") MAX_TOKENS + @JsonProperty("endTurn") END_TURN, + @JsonProperty("stopSequence") STOP_SEQUENCE, + @JsonProperty("maxTokens") MAX_TOKENS } public static Builder builder() { @@ -843,6 +967,7 @@ public CreateMessageResult build() { // Pagination Interfaces // --------------------------- @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record PaginatedRequest(@JsonProperty("cursor") String cursor) { } @@ -854,6 +979,7 @@ public record PaginatedResult(@JsonProperty("nextCursor") String nextCursor) { // --------------------------- // Progress and Logging // --------------------------- + @JsonIgnoreProperties(ignoreUnknown = true) public record ProgressNotification(// @formatter:off @JsonProperty("progressToken") String progressToken, @JsonProperty("progress") double progress, @@ -870,6 +996,7 @@ public record ProgressNotification(// @formatter:off * @param logger The logger that generated the message. * @param data JSON-serializable logging data. */ + @JsonIgnoreProperties(ignoreUnknown = true) public record LoggingMessageNotification(// @formatter:off @JsonProperty("level") LoggingLevel level, @JsonProperty("logger") String logger, @@ -970,61 +1097,48 @@ public record CompleteCompletion(// @formatter:off @JsonSubTypes.Type(value = EmbeddedResource.class, name = "resource") }) public sealed interface Content permits TextContent, ImageContent, EmbeddedResource { - String type(); + default String type() { + if (this instanceof TextContent) { + return "text"; + } + else if (this instanceof ImageContent) { + return "image"; + } + else if (this instanceof EmbeddedResource) { + return "resource"; + } + throw new IllegalArgumentException("Unknown content type: " + this); + } } @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record TextContent( // @formatter:off @JsonProperty("audience") List audience, @JsonProperty("priority") Double priority, - @JsonProperty("type") String type, @JsonProperty("text") String text) implements Content { // @formatter:on - public TextContent { - type = "text"; - } - - public String type() { - return type; - } - public TextContent(String content) { - this(null, null, "text", content); + this(null, null, content); } } @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record ImageContent( // @formatter:off @JsonProperty("audience") List audience, @JsonProperty("priority") Double priority, - @JsonProperty("type") String type, @JsonProperty("data") String data, @JsonProperty("mimeType") String mimeType) implements Content { // @formatter:on - - public ImageContent { - type = "image"; - } - - public String type() { - return type; - } } @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record EmbeddedResource( // @formatter:off @JsonProperty("audience") List audience, @JsonProperty("priority") Double priority, - @JsonProperty("type") String type, @JsonProperty("resource") ResourceContents resource) implements Content { // @formatter:on - - public EmbeddedResource { - type = "resource"; - } - - public String type() { - return type; - } } // --------------------------- @@ -1041,6 +1155,7 @@ public String type() { * for referencing the root in other parts of the application. */ @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) public record Root( // @formatter:off @JsonProperty("uri") String uri, @JsonProperty("name") String name) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java new file mode 100644 index 000000000..bcdf22486 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -0,0 +1,354 @@ +package io.modelcontextprotocol.spec; + +import java.time.Duration; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoSink; +import reactor.core.publisher.Sinks; + +/** + * Represents a Model Control Protocol (MCP) session on the server side. It manages + * bidirectional JSON-RPC communication with the client. + */ +public class McpServerSession implements McpSession { + + private static final Logger logger = LoggerFactory.getLogger(McpServerSession.class); + + private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); + + private final String id; + + private final AtomicLong requestCounter = new AtomicLong(0); + + private final InitRequestHandler initRequestHandler; + + private final InitNotificationHandler initNotificationHandler; + + private final Map> requestHandlers; + + private final Map notificationHandlers; + + private final McpServerTransport transport; + + private final Sinks.One exchangeSink = Sinks.one(); + + private final AtomicReference clientCapabilities = new AtomicReference<>(); + + private final AtomicReference clientInfo = new AtomicReference<>(); + + private static final int STATE_UNINITIALIZED = 0; + + private static final int STATE_INITIALIZING = 1; + + private static final int STATE_INITIALIZED = 2; + + private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED); + + /** + * Creates a new server session with the given parameters and the transport to use. + * @param id session id + * @param transport the transport to use + * @param initHandler called when a + * {@link io.modelcontextprotocol.spec.McpSchema.InitializeRequest} is received by the + * server + * @param initNotificationHandler called when a + * {@link McpSchema.METHOD_NOTIFICATION_INITIALIZED} is received. + * @param requestHandlers map of request handlers to use + * @param notificationHandlers map of notification handlers to use + */ + public McpServerSession(String id, McpServerTransport transport, InitRequestHandler initHandler, + InitNotificationHandler initNotificationHandler, Map> requestHandlers, + Map notificationHandlers) { + this.id = id; + this.transport = transport; + this.initRequestHandler = initHandler; + this.initNotificationHandler = initNotificationHandler; + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + } + + /** + * Retrieve the session id. + * @return session id + */ + public String getId() { + return this.id; + } + + /** + * Called upon successful initialization sequence between the client and the server + * with the client capabilities and information. + * + * Initialization + * Spec + * @param clientCapabilities the capabilities the connected client provides + * @param clientInfo the information about the connected client + */ + public void init(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { + this.clientCapabilities.lazySet(clientCapabilities); + this.clientInfo.lazySet(clientInfo); + } + + private String generateRequestId() { + return this.id + "-" + this.requestCounter.getAndIncrement(); + } + + @Override + public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + String requestId = this.generateRequestId(); + + return Mono.create(sink -> { + this.pendingResponses.put(requestId, sink); + McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, + requestId, requestParams); + this.transport.sendMessage(jsonrpcRequest).subscribe(v -> { + }, error -> { + this.pendingResponses.remove(requestId); + sink.error(error); + }); + }).timeout(Duration.ofSeconds(10)).handle((jsonRpcResponse, sink) -> { + if (jsonRpcResponse.error() != null) { + sink.error(new McpError(jsonRpcResponse.error())); + } + else { + if (typeRef.getType().equals(Void.class)) { + sink.complete(); + } + else { + sink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef)); + } + } + }); + } + + @Override + public Mono sendNotification(String method, Map params) { + McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + method, params); + return this.transport.sendMessage(jsonrpcNotification); + } + + /** + * Called by the {@link McpServerTransportProvider} once the session is determined. + * The purpose of this method is to dispatch the message to an appropriate handler as + * specified by the MCP server implementation + * ({@link io.modelcontextprotocol.server.McpAsyncServer} or + * {@link io.modelcontextprotocol.server.McpSyncServer}) via + * {@link McpServerSession.Factory} that the server creates. + * @param message the incoming JSON-RPC message + * @return a Mono that completes when the message is processed + */ + public Mono handle(McpSchema.JSONRPCMessage message) { + return Mono.defer(() -> { + // TODO handle errors for communication to without initialization happening + // first + if (message instanceof McpSchema.JSONRPCResponse response) { + logger.debug("Received Response: {}", response); + var sink = pendingResponses.remove(response.id()); + if (sink == null) { + logger.warn("Unexpected response for unknown id {}", response.id()); + } + else { + sink.success(response); + } + return Mono.empty(); + } + else if (message instanceof McpSchema.JSONRPCRequest request) { + logger.debug("Received request: {}", request); + return handleIncomingRequest(request).onErrorResume(error -> { + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), null)); + // TODO: Should the error go to SSE or back as POST return? + return this.transport.sendMessage(errorResponse).then(Mono.empty()); + }).flatMap(this.transport::sendMessage); + } + else if (message instanceof McpSchema.JSONRPCNotification notification) { + // TODO handle errors for communication to without initialization + // happening first + logger.debug("Received notification: {}", notification); + // TODO: in case of error, should the POST request be signalled? + return handleIncomingNotification(notification) + .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage())); + } + else { + logger.warn("Received unknown message type: {}", message); + return Mono.empty(); + } + }); + } + + /** + * Handles an incoming JSON-RPC request by routing it to the appropriate handler. + * @param request The incoming JSON-RPC request + * @return A Mono containing the JSON-RPC response + */ + private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request) { + return Mono.defer(() -> { + Mono resultMono; + if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { + // TODO handle situation where already initialized! + McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(request.params(), + new TypeReference() { + }); + + this.state.lazySet(STATE_INITIALIZING); + this.init(initializeRequest.capabilities(), initializeRequest.clientInfo()); + resultMono = this.initRequestHandler.handle(initializeRequest); + } + else { + // TODO handle errors for communication to this session without + // initialization happening first + var handler = this.requestHandlers.get(request.method()); + if (handler == null) { + MethodNotFoundError error = getMethodNotFoundError(request.method()); + return Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, + error.message(), error.data()))); + } + + resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, request.params())); + } + return resultMono + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) + .onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), + null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), null)))); // TODO: add error message + // through the data field + }); + } + + /** + * Handles an incoming JSON-RPC notification by routing it to the appropriate handler. + * @param notification The incoming JSON-RPC notification + * @return A Mono that completes when the notification is processed + */ + private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification) { + return Mono.defer(() -> { + if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) { + this.state.lazySet(STATE_INITIALIZED); + exchangeSink.tryEmitValue(new McpAsyncServerExchange(this, clientCapabilities.get(), clientInfo.get())); + return this.initNotificationHandler.handle(); + } + + var handler = notificationHandlers.get(notification.method()); + if (handler == null) { + logger.error("No handler registered for notification method: {}", notification.method()); + return Mono.empty(); + } + return this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, notification.params())); + }); + } + + record MethodNotFoundError(String method, String message, Object data) { + } + + static MethodNotFoundError getMethodNotFoundError(String method) { + switch (method) { + case McpSchema.METHOD_ROOTS_LIST: + return new MethodNotFoundError(method, "Roots not supported", + Map.of("reason", "Client does not have roots capability")); + default: + return new MethodNotFoundError(method, "Method not found: " + method, null); + } + } + + @Override + public Mono closeGracefully() { + return this.transport.closeGracefully(); + } + + @Override + public void close() { + this.transport.close(); + } + + /** + * Request handler for the initialization request. + */ + public interface InitRequestHandler { + + /** + * Handles the initialization request. + * @param initializeRequest the initialization request by the client + * @return a Mono that will emit the result of the initialization + */ + Mono handle(McpSchema.InitializeRequest initializeRequest); + + } + + /** + * Notification handler for the initialization notification from the client. + */ + public interface InitNotificationHandler { + + /** + * Specifies an action to take upon successful initialization. + * @return a Mono that will complete when the initialization is acted upon. + */ + Mono handle(); + + } + + /** + * A handler for client-initiated notifications. + */ + public interface NotificationHandler { + + /** + * Handles a notification from the client. + * @param exchange the exchange associated with the client that allows calling + * back to the connected client or inspecting its capabilities. + * @param params the parameters of the notification. + * @return a Mono that completes once the notification is handled. + */ + Mono handle(McpAsyncServerExchange exchange, Object params); + + } + + /** + * A handler for client-initiated requests. + * + * @param the type of the response that is expected as a result of handling the + * request. + */ + public interface RequestHandler { + + /** + * Handles a request from the client. + * @param exchange the exchange associated with the client that allows calling + * back to the connected client or inspecting its capabilities. + * @param params the parameters of the request. + * @return a Mono that will emit the response to the request. + */ + Mono handle(McpAsyncServerExchange exchange, Object params); + + } + + /** + * Factory for creating server sessions which delegate to a provided 1:1 transport + * with a connected client. + */ + @FunctionalInterface + public interface Factory { + + /** + * Creates a new 1:1 representation of the client-server interaction. + * @param sessionTransport the transport to use for communication with the client. + * @return a new server session. + */ + McpServerSession create(McpServerTransport sessionTransport); + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java new file mode 100644 index 000000000..632b8cee6 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java @@ -0,0 +1,11 @@ +package io.modelcontextprotocol.spec; + +/** + * Marker interface for the server-side MCP transport. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +public interface McpServerTransport extends McpTransport { + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java new file mode 100644 index 000000000..dba8cc43f --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java @@ -0,0 +1,66 @@ +package io.modelcontextprotocol.spec; + +import java.util.Map; + +import reactor.core.publisher.Mono; + +/** + * The core building block providing the server-side MCP transport. Implement this + * interface to bridge between a particular server-side technology and the MCP server + * transport layer. + * + *

    + * The lifecycle of the provider dictates that it be created first, upon application + * startup, and then passed into either + * {@link io.modelcontextprotocol.server.McpServer#sync(McpServerTransportProvider)} or + * {@link io.modelcontextprotocol.server.McpServer#async(McpServerTransportProvider)}. As + * a result of the MCP server creation, the provider will be notified of a + * {@link McpServerSession.Factory} which will be used to handle a 1:1 communication + * between a newly connected client and the server. The provider's responsibility is to + * create instances of {@link McpServerTransport} that the session will utilise during the + * session lifetime. + * + *

    + * Finally, the {@link McpServerTransport}s can be closed in bulk when {@link #close()} or + * {@link #closeGracefully()} are called as part of the normal application shutdown event. + * Individual {@link McpServerTransport}s can also be closed on a per-session basis, where + * the {@link McpServerSession#close()} or {@link McpServerSession#closeGracefully()} + * closes the provided transport. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpServerTransportProvider { + + /** + * Sets the session factory that will be used to create sessions for new clients. An + * implementation of the MCP server MUST call this method before any MCP interactions + * take place. + * @param sessionFactory the session factory to be used for initiating client sessions + */ + void setSessionFactory(McpServerSession.Factory sessionFactory); + + /** + * Sends a notification to all connected clients. + * @param method the name of the notification method to be called on the clients + * @param params a map of parameters to be sent with the notification + * @return a Mono that completes when the notification has been broadcast + * @see McpSession#sendNotification(String, Map) + */ + Mono notifyClients(String method, Map params); + + /** + * Immediately closes all the transports with connected clients and releases any + * associated resources. + */ + default void close() { + this.closeGracefully().subscribe(); + } + + /** + * Gracefully closes all the transports with connected clients and releases any + * associated resources asynchronously. + * @return a {@link Mono} that completes when the connections have been closed. + */ + Mono closeGracefully(); + +} diff --git a/mcp/src/main/java/org/springframework/ai/mcp/spec/McpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java similarity index 66% rename from mcp/src/main/java/org/springframework/ai/mcp/spec/McpSession.java rename to mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java index 533baa8dc..b97c3ccc4 100644 --- a/mcp/src/main/java/org/springframework/ai/mcp/spec/McpSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java @@ -1,20 +1,8 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.spec; +package io.modelcontextprotocol.spec; import java.util.Map; @@ -38,14 +26,15 @@ public interface McpSession { /** - * Sends a request to the model server and expects a response of type T. + * Sends a request to the model counterparty and expects a response of type T. * *

    * This method handles the request-response pattern where a response is expected from - * the server. The response type is determined by the provided TypeReference. + * the client or server. The response type is determined by the provided + * TypeReference. *

    * @param the type of the expected response - * @param method the name of the method to be called on the server + * @param method the name of the method to be called on the counterparty * @param requestParams the parameters to be sent with the request * @param typeRef the TypeReference describing the expected response type * @return a Mono that will emit the response when received @@ -53,11 +42,11 @@ public interface McpSession { Mono sendRequest(String method, Object requestParams, TypeReference typeRef); /** - * Sends a notification to the model server without parameters. + * Sends a notification to the model client or server without parameters. * *

    * This method implements the notification pattern where no response is expected from - * the server. It's useful for fire-and-forget scenarios. + * the counterparty. It's useful for fire-and-forget scenarios. *

    * @param method the name of the notification method to be called on the server * @return a Mono that completes when the notification has been sent @@ -67,13 +56,13 @@ default Mono sendNotification(String method) { } /** - * Sends a notification to the model server with parameters. + * Sends a notification to the model client or server with parameters. * *

    * Similar to {@link #sendNotification(String)} but allows sending additional * parameters with the notification. *

    - * @param method the name of the notification method to be called on the server + * @param method the name of the notification method to be sent to the counterparty * @param params a map of parameters to be sent with the notification * @return a Mono that completes when the notification has been sent */ diff --git a/mcp/src/main/java/org/springframework/ai/mcp/spec/McpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java similarity index 76% rename from mcp/src/main/java/org/springframework/ai/mcp/spec/McpTransport.java rename to mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java index 48c696561..f698d8789 100644 --- a/mcp/src/main/java/org/springframework/ai/mcp/spec/McpTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java @@ -1,28 +1,15 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.spec; +package io.modelcontextprotocol.spec; import java.util.function.Function; import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; import reactor.core.publisher.Mono; -import org.springframework.ai.mcp.spec.McpSchema.JSONRPCMessage; - /** * Defines the asynchronous transport layer for the Model Context Protocol (MCP). * @@ -59,8 +46,13 @@ public interface McpTransport { * This method should be called before any message exchange can occur. It sets up the * necessary resources and establishes the connection to the server. *

    + * @deprecated This is only relevant for client-side transports and will be removed + * from this interface in 0.9.0. */ - Mono connect(Function, Mono> handler); + @Deprecated + default Mono connect(Function, Mono> handler) { + return Mono.empty(); + } /** * Closes the transport connection and releases any associated resources. @@ -82,7 +74,7 @@ default void close() { Mono closeGracefully(); /** - * Sends a message to the server asynchronously. + * Sends a message to the peer asynchronously. * *

    * This method handles the transmission of messages to the server in an asynchronous diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java new file mode 100644 index 000000000..704daee0f --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java @@ -0,0 +1,15 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +*/ +package io.modelcontextprotocol.spec; + +/** + * Marker interface for the server-side MCP transport. + * + * @author Christian Tzolov + * @deprecated This class will be removed in 0.9.0. Use {@link McpServerTransport}. + */ +@Deprecated +public interface ServerMcpTransport extends McpTransport { + +} diff --git a/mcp/src/main/java/org/springframework/ai/mcp/util/Assert.java b/mcp/src/main/java/io/modelcontextprotocol/util/Assert.java similarity index 80% rename from mcp/src/main/java/org/springframework/ai/mcp/util/Assert.java rename to mcp/src/main/java/io/modelcontextprotocol/util/Assert.java index 0dc1fc103..d68188c6f 100644 --- a/mcp/src/main/java/org/springframework/ai/mcp/util/Assert.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Assert.java @@ -1,20 +1,8 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.util; +package io.modelcontextprotocol.util; import java.util.Collection; diff --git a/mcp/src/main/java/org/springframework/ai/mcp/util/Utils.java b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java similarity index 71% rename from mcp/src/main/java/org/springframework/ai/mcp/util/Utils.java rename to mcp/src/main/java/io/modelcontextprotocol/util/Utils.java index 622ad269a..0f799ca0f 100644 --- a/mcp/src/main/java/org/springframework/ai/mcp/util/Utils.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java @@ -1,20 +1,8 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.util; +package io.modelcontextprotocol.util; import java.util.Collection; import java.util.Map; diff --git a/mcp/src/main/java/org/springframework/ai/mcp/client/McpAsyncClient.java b/mcp/src/main/java/org/springframework/ai/mcp/client/McpAsyncClient.java deleted file mode 100644 index 7f6ccc122..000000000 --- a/mcp/src/main/java/org/springframework/ai/mcp/client/McpAsyncClient.java +++ /dev/null @@ -1,933 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.ai.mcp.client; - -import java.time.Duration; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.function.Consumer; -import java.util.function.Function; - -import com.fasterxml.jackson.core.type.TypeReference; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; - -import org.springframework.ai.mcp.spec.ClientMcpTransport; -import org.springframework.ai.mcp.spec.DefaultMcpSession; -import org.springframework.ai.mcp.spec.DefaultMcpSession.NotificationHandler; -import org.springframework.ai.mcp.spec.DefaultMcpSession.RequestHandler; -import org.springframework.ai.mcp.spec.McpError; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.ClientCapabilities; -import org.springframework.ai.mcp.spec.McpSchema.ClientCapabilities.RootCapabilities; -import org.springframework.ai.mcp.spec.McpSchema.ClientCapabilities.Sampling; -import org.springframework.ai.mcp.spec.McpSchema.CreateMessageRequest; -import org.springframework.ai.mcp.spec.McpSchema.CreateMessageResult; -import org.springframework.ai.mcp.spec.McpSchema.GetPromptRequest; -import org.springframework.ai.mcp.spec.McpSchema.GetPromptResult; -import org.springframework.ai.mcp.spec.McpSchema.Implementation; -import org.springframework.ai.mcp.spec.McpSchema.ListPromptsResult; -import org.springframework.ai.mcp.spec.McpSchema.LoggingLevel; -import org.springframework.ai.mcp.spec.McpSchema.LoggingMessageNotification; -import org.springframework.ai.mcp.spec.McpSchema.PaginatedRequest; -import org.springframework.ai.mcp.spec.McpSchema.Root; -import org.springframework.ai.mcp.spec.McpTransport; -import org.springframework.ai.mcp.util.Assert; -import org.springframework.ai.mcp.util.Utils; - -/** - * The Model Context Protocol (MCP) client implementation that provides asynchronous - * communication with MCP servers using Project Reactor's Mono and Flux types. - * - *

    - * This client implements the MCP specification, enabling AI models to interact with - * external tools and resources through a standardized interface. Key features include: - *

      - *
    • Asynchronous communication using reactive programming patterns - *
    • Tool discovery and invocation for server-provided functionality - *
    • Resource access and management with URI-based addressing - *
    • Prompt template handling for standardized AI interactions - *
    • Real-time notifications for tools, resources, and prompts changes - *
    • Structured logging with configurable severity levels - *
    • Message sampling for AI model interactions - *
    - * - *

    - * The client follows a lifecycle: - *

      - *
    1. Initialization - Establishes connection and negotiates capabilities - *
    2. Normal Operation - Handles requests and notifications - *
    3. Graceful Shutdown - Ensures clean connection termination - *
    - * - *

    - * This implementation uses Project Reactor for non-blocking operations, making it - * suitable for high-throughput scenarios and reactive applications. All operations return - * Mono or Flux types that can be composed into reactive pipelines. - * - * @author Dariusz Jędrzejczyk - * @author Christian Tzolov - * @see McpClient - * @see McpSchema - * @see DefaultMcpSession - */ -public class McpAsyncClient { - - private static final Logger logger = LoggerFactory.getLogger(McpAsyncClient.class); - - private static TypeReference VOID_TYPE_REFERENCE = new TypeReference<>() { - }; - - /** - * The MCP session implementation that manages bidirectional JSON-RPC communication - * between clients and servers. - */ - private final DefaultMcpSession mcpSession; - - /** - * Client capabilities. - */ - private final McpSchema.ClientCapabilities clientCapabilities; - - /** - * Client implementation information. - */ - private final McpSchema.Implementation clientInfo; - - /** - * Server capabilities. - */ - private McpSchema.ServerCapabilities serverCapabilities; - - /** - * Server implementation information. - */ - private McpSchema.Implementation serverInfo; - - /** - * Roots define the boundaries of where servers can operate within the filesystem, - * allowing them to understand which directories and files they have access to. - * Servers can request the list of roots from supporting clients and receive - * notifications when that list changes. - */ - private final ConcurrentHashMap roots; - - /** - * MCP provides a standardized way for servers to request LLM sampling ("completions" - * or "generations") from language models via clients. This flow allows clients to - * maintain control over model access, selection, and permissions while enabling - * servers to leverage AI capabilities—with no server API keys necessary. Servers can - * request text or image-based interactions and optionally include context from MCP - * servers in their prompts. - */ - private Function> samplingHandler; - - /** - * Client transport implementation. - */ - private final McpTransport transport; - - /** - * Supported protocol versions. - */ - private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); - - /** - * Create a new McpAsyncClient with the given transport and session request-response - * timeout. - * @param transport the transport to use. - * @param requestTimeout the session request-response timeout. - * @param features the MCP Client supported features. - */ - McpAsyncClient(ClientMcpTransport transport, Duration requestTimeout, McpClientFeatures.Async features) { - - Assert.notNull(transport, "Transport must not be null"); - Assert.notNull(requestTimeout, "Request timeout must not be null"); - - this.clientInfo = features.clientInfo(); - this.clientCapabilities = features.clientCapabilities(); - this.transport = transport; - this.roots = new ConcurrentHashMap<>(features.roots()); - - // Request Handlers - Map> requestHandlers = new HashMap<>(); - - // Roots List Request Handler - if (this.clientCapabilities.roots() != null) { - requestHandlers.put(McpSchema.METHOD_ROOTS_LIST, rootsListRequestHandler()); - } - - // Sampling Handler - if (this.clientCapabilities.sampling() != null) { - if (features.samplingHandler() == null) { - throw new McpError("Sampling handler must not be null when client capabilities include sampling"); - } - this.samplingHandler = features.samplingHandler(); - requestHandlers.put(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, samplingCreateMessageHandler()); - } - - // Notification Handlers - Map notificationHandlers = new HashMap<>(); - - // Tools Change Notification - List, Mono>> toolsChangeConsumersFinal = new ArrayList<>(); - toolsChangeConsumersFinal - .add((notification) -> Mono.fromRunnable(() -> logger.info("Tools changed: {}", notification))); - - if (!Utils.isEmpty(features.toolsChangeConsumers())) { - toolsChangeConsumersFinal.addAll(features.toolsChangeConsumers()); - } - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, - asyncToolsChangeNotificationHandler(toolsChangeConsumersFinal)); - - // Resources Change Notification - List, Mono>> resourcesChangeConsumersFinal = new ArrayList<>(); - resourcesChangeConsumersFinal - .add((notification) -> Mono.fromRunnable(() -> logger.info("Resources changed: {}", notification))); - - if (!Utils.isEmpty(features.resourcesChangeConsumers())) { - resourcesChangeConsumersFinal.addAll(features.resourcesChangeConsumers()); - } - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, - asyncResourcesChangeNotificationHandler(resourcesChangeConsumersFinal)); - - // Prompts Change Notification - List, Mono>> promptsChangeConsumersFinal = new ArrayList<>(); - promptsChangeConsumersFinal - .add((notification) -> Mono.fromRunnable(() -> logger.info("Prompts changed: {}", notification))); - if (!Utils.isEmpty(features.promptsChangeConsumers())) { - promptsChangeConsumersFinal.addAll(features.promptsChangeConsumers()); - } - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, - asyncPromptsChangeNotificationHandler(promptsChangeConsumersFinal)); - - // Utility Logging Notification - List>> loggingConsumersFinal = new ArrayList<>(); - loggingConsumersFinal.add((notification) -> Mono.fromRunnable(() -> logger.info("Logging: {}", notification))); - if (!Utils.isEmpty(features.loggingConsumers())) { - loggingConsumersFinal.addAll(features.loggingConsumers()); - } - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE, - asyncLoggingNotificationHandler(loggingConsumersFinal)); - - this.mcpSession = new DefaultMcpSession(requestTimeout, transport, requestHandlers, notificationHandlers); - - } - - /** - * Create a new McpAsyncClient with the given transport and session request-response - * timeout. - * @param transport the transport to use. - * @param requestTimeout the session request-response timeout. - * @param clientInfo the client implementation information. - * @param clientCapabilities the client capabilities. - * @param roots the roots. - * @param toolsChangeConsumers the tools change consumers. - * @param resourcesChangeConsumers the resources change consumers. - * @param promptsChangeConsumers the prompts change consumers. - * @param loggingConsumers the logging consumers. - * @param samplingHandler the sampling handler. - * @deprecated Use {@link McpClient#async(ClientMcpTransport)} to obtain an instance. - */ - @Deprecated - public McpAsyncClient(ClientMcpTransport transport, Duration requestTimeout, Implementation clientInfo, - ClientCapabilities clientCapabilities, Map roots, - List>> toolsChangeConsumers, - List>> resourcesChangeConsumers, - List>> promptsChangeConsumers, - List> loggingConsumers, - Function samplingHandler) { - - Assert.notNull(transport, "Transport must not be null"); - Assert.notNull(requestTimeout, "Request timeout must not be null"); - Assert.notNull(clientInfo, "Client info must not be null"); - - this.protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); - - this.clientInfo = clientInfo; - - this.clientCapabilities = (clientCapabilities != null) ? clientCapabilities - : new McpSchema.ClientCapabilities(null, !Utils.isEmpty(roots) ? new RootCapabilities(false) : null, - samplingHandler != null ? new Sampling() : null); - - this.transport = transport; - - this.roots = roots != null ? new ConcurrentHashMap<>(roots) : new ConcurrentHashMap<>(); - - // Request Handlers - Map> requestHandlers = new HashMap<>(); - - // Roots List Request Handler - if (this.clientCapabilities.roots() != null) { - requestHandlers.put(McpSchema.METHOD_ROOTS_LIST, rootsListRequestHandler()); - } - - // Sampling Handler - if (this.clientCapabilities.sampling() != null) { - if (samplingHandler == null) { - throw new McpError("Sampling handler must not be null when client capabilities include sampling"); - } - this.samplingHandler = r -> Mono.fromCallable(() -> samplingHandler.apply(r)) - .subscribeOn(Schedulers.boundedElastic()); - requestHandlers.put(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, samplingCreateMessageHandler()); - } - - // Notification Handlers - Map notificationHandlers = new HashMap<>(); - - // Tools Change Notification - List>> toolsChangeConsumersFinal = new ArrayList<>(); - toolsChangeConsumersFinal.add((notification) -> logger.info("Tools changed: {}", notification)); - if (!Utils.isEmpty(toolsChangeConsumers)) { - toolsChangeConsumersFinal.addAll(toolsChangeConsumers); - } - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, - toolsChangeNotificationHandler(toolsChangeConsumersFinal)); - - // Resources Change Notification - List>> resourcesChangeConsumersFinal = new ArrayList<>(); - resourcesChangeConsumersFinal.add((notification) -> logger.info("Resources changed: {}", notification)); - if (!Utils.isEmpty(resourcesChangeConsumers)) { - resourcesChangeConsumersFinal.addAll(resourcesChangeConsumers); - } - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, - resourcesChangeNotificationHandler(resourcesChangeConsumersFinal)); - - // Prompts Change Notification - List>> promptsChangeConsumersFinal = new ArrayList<>(); - promptsChangeConsumersFinal.add((notification) -> logger.info("Prompts changed: {}", notification)); - if (!Utils.isEmpty(promptsChangeConsumers)) { - promptsChangeConsumersFinal.addAll(promptsChangeConsumers); - } - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, - promptsChangeNotificationHandler(promptsChangeConsumersFinal)); - - // Utility Logging Notification - List> loggingConsumersFinal = new ArrayList<>(); - loggingConsumersFinal.add((notification) -> logger.info("Logging: {}", notification)); - if (!Utils.isEmpty(loggingConsumers)) { - loggingConsumersFinal.addAll(loggingConsumers); - } - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE, - loggingNotificationHandler(loggingConsumersFinal)); - - this.mcpSession = new DefaultMcpSession(requestTimeout, transport, requestHandlers, notificationHandlers); - - } - - // -------------------------- - // Lifecycle - // -------------------------- - /** - * The initialization phase MUST be the first interaction between client and server. - * During this phase, the client and server: - *

      - *
    • Establish protocol version compatibility
    • - *
    • Exchange and negotiate capabilities
    • - *
    • Share implementation details
    • - *
    - *
    - * The client MUST initiate this phase by sending an initialize request containing: - *
      - *
    • The protocol version the client supports
    • - *
    • The client's capabilities
    • - *
    • Client implementation information
    • - *
    - * - * The server MUST respond with its own capabilities and information: - * {@link McpSchema.ServerCapabilities}.
    - * After successful initialization, the client MUST send an initialized notification - * to indicate it is ready to begin normal operations. - * - *
    - * - * Initialization - * Spec - * @return the initialize result. - */ - public Mono initialize() { - - String latestVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); - - McpSchema.InitializeRequest initializeRequest = new McpSchema.InitializeRequest(// @formatter:off - latestVersion, - this.clientCapabilities, - this.clientInfo); // @formatter:on - - Mono result = this.mcpSession.sendRequest(McpSchema.METHOD_INITIALIZE, - initializeRequest, new TypeReference() { - }); - - return result.flatMap(initializeResult -> { - - this.serverCapabilities = initializeResult.capabilities(); - this.serverInfo = initializeResult.serverInfo(); - - logger.info("Server response with Protocol: {}, Capabilities: {}, Info: {} and Instructions {}", - initializeResult.protocolVersion(), initializeResult.capabilities(), initializeResult.serverInfo(), - initializeResult.instructions()); - - if (!this.protocolVersions.contains(initializeResult.protocolVersion())) { - return Mono.error(new McpError( - "Unsupported protocol version from the server: " + initializeResult.protocolVersion())); - } - else { - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null) - .thenReturn(initializeResult); - } - }); - } - - /** - * Get the server capabilities that define the supported features and functionality. - * @return The server capabilities - */ - public McpSchema.ServerCapabilities getServerCapabilities() { - return this.serverCapabilities; - } - - /** - * Get the server implementation information. - * @return The server implementation details - */ - public McpSchema.Implementation getServerInfo() { - return this.serverInfo; - } - - /** - * Get the client capabilities that define the supported features and functionality. - * @return The client capabilities - */ - public ClientCapabilities getClientCapabilities() { - return this.clientCapabilities; - } - - /** - * Get the client implementation information. - * @return The client implementation details - */ - public McpSchema.Implementation getClientInfo() { - return this.clientInfo; - } - - /** - * Closes the client connection immediately. - */ - public void close() { - this.mcpSession.close(); - } - - /** - * Gracefully closes the client connection. - * @return A Mono that completes when the connection is closed - */ - public Mono closeGracefully() { - return this.mcpSession.closeGracefully(); - } - - // -------------------------- - // Basic Utilites - // -------------------------- - - /** - * Sends a ping request to the server. - * @return A Mono that completes with the server's ping response - */ - public Mono ping() { - return this.mcpSession.sendRequest(McpSchema.METHOD_PING, null, new TypeReference() { - }); - } - - // -------------------------- - // Roots - // -------------------------- - /** - * Adds a new root to the client's root list. - * @param root The root to add - * @return A Mono that completes when the root is added and notifications are sent - */ - public Mono addRoot(Root root) { - - if (root == null) { - return Mono.error(new McpError("Root must not be null")); - } - - if (this.clientCapabilities.roots() == null) { - return Mono.error(new McpError("Client must be configured with roots capabilities")); - } - - if (this.roots.containsKey(root.uri())) { - return Mono.error(new McpError("Root with uri '" + root.uri() + "' already exists")); - } - - this.roots.put(root.uri(), root); - - logger.info("Added root: {}", root); - - if (this.clientCapabilities.roots().listChanged()) { - return this.rootsListChangedNotification(); - } - return Mono.empty(); - } - - /** - * Removes a root from the client's root list. - * @param rootUri The URI of the root to remove - * @return A Mono that completes when the root is removed and notifications are sent - */ - public Mono removeRoot(String rootUri) { - - if (rootUri == null) { - return Mono.error(new McpError("Root uri must not be null")); - } - - if (this.clientCapabilities.roots() == null) { - return Mono.error(new McpError("Client must be configured with roots capabilities")); - } - - Root removed = this.roots.remove(rootUri); - - if (removed != null) { - logger.info("Removed Root: {}", rootUri); - if (this.clientCapabilities.roots().listChanged()) { - return this.rootsListChangedNotification(); - } - return Mono.empty(); - } - return Mono.error(new McpError("Root with uri '" + rootUri + "' not found")); - } - - /** - * Manually sends a roots/list_changed notification. The addRoot and removeRoot - * methods automatically send the roots/list_changed notification. - * @return A Mono that completes when the notification is sent - */ - public Mono rootsListChangedNotification() { - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED); - } - - private RequestHandler rootsListRequestHandler() { - return params -> { - McpSchema.PaginatedRequest request = transport.unmarshalFrom(params, - new TypeReference() { - }); - - List roots = this.roots.values().stream().toList(); - - return Mono.just(new McpSchema.ListRootsResult(roots)); - }; - } - - // -------------------------- - // Sampling - // -------------------------- - private RequestHandler samplingCreateMessageHandler() { - return params -> { - McpSchema.CreateMessageRequest request = transport.unmarshalFrom(params, - new TypeReference() { - }); - - return this.samplingHandler.apply(request); - }; - } - - // -------------------------- - // Tools - // -------------------------- - private static final TypeReference CALL_TOOL_RESULT_TYPE_REF = new TypeReference<>() { - }; - - private static final TypeReference LIST_TOOLS_RESULT_TYPE_REF = new TypeReference<>() { - }; - - /** - * Calls a tool provided by the server. Tools enable servers to expose executable - * functionality that can interact with external systems, perform computations, and - * take actions in the real world. - * @param callToolRequest The request containing: - name: The name of the tool to call - * (must match a tool name from tools/list) - arguments: Arguments that conform to the - * tool's input schema - * @return A Mono that emits the tool execution result containing: - content: List of - * content items (text, images, or embedded resources) representing the tool's output - * - isError: Boolean indicating if the execution failed (true) or succeeded - * (false/absent) - */ - public Mono callTool(McpSchema.CallToolRequest callToolRequest) { - return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF); - } - - /** - * Retrieves the list of all tools provided by the server. - * @return A Mono that emits the list of tools result containing: - tools: List of - * available tools, each with a name, description, and input schema - nextCursor: - * Optional cursor for pagination if more tools are available - */ - public Mono listTools() { - return this.listTools(null); - } - - /** - * Retrieves a paginated list of tools provided by the server. - * @param cursor Optional pagination cursor from a previous list request - * @return A Mono that emits the list of tools result containing: - tools: List of - * available tools, each with a name, description, and input schema - nextCursor: - * Optional cursor for pagination if more tools are available - */ - public Mono listTools(String cursor) { - return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor), - LIST_TOOLS_RESULT_TYPE_REF); - } - - /** - * Creates a notification handler for tools/list_changed notifications from the - * server. When the server's available tools change, it sends a notification to inform - * connected clients. This handler automatically fetches the updated tool list and - * distributes it to all registered consumers. - * @param toolsChangeConsumers List of consumers that will be notified when the tools - * list changes. Each consumer receives the complete updated list of tools. - * @return A NotificationHandler that processes tools/list_changed notifications by: - * 1. Fetching the current list of tools from the server 2. Distributing the updated - * list to all registered consumers 3. Handling any errors that occur during this - * process - */ - @Deprecated - private NotificationHandler toolsChangeNotificationHandler( - List>> toolsChangeConsumers) { - - return params -> listTools().flatMap(listToolsResult -> Mono.fromRunnable(() -> { - for (Consumer> toolsChangeConsumer : toolsChangeConsumers) { - toolsChangeConsumer.accept(listToolsResult.tools()); - } - }).subscribeOn(Schedulers.boundedElastic())).onErrorResume(error -> { - logger.error("Error handling tools list change notification", error); - return Mono.empty(); - }).then(); // Convert to Mono - } - - /** - * Creates a notification handler for tools/list_changed notifications from the - * server. When the server's available tools change, it sends a notification to inform - * connected clients. This handler automatically fetches the updated tool list and - * distributes it to all registered consumers. - * @param toolsChangeConsumers List of consumers that will be notified when the tools - * list changes. Each consumer receives the complete updated list of tools. - * @return A NotificationHandler that processes tools/list_changed notifications by: - * 1. Fetching the current list of tools from the server 2. Distributing the updated - * list to all registered consumers 3. Handling any errors that occur during this - * process - */ - private NotificationHandler asyncToolsChangeNotificationHandler( - List, Mono>> toolsChangeConsumers) { - // TODO: params are not used yet - return params -> listTools().flatMap(listToolsResult -> Flux.fromIterable(toolsChangeConsumers) - .flatMap(consumer -> consumer.apply(listToolsResult.tools())) - .onErrorResume(error -> { - logger.error("Error handling tools list change notification", error); - return Mono.empty(); - }) - .then()); - } - - // -------------------------- - // Resources - // -------------------------- - - private static final TypeReference LIST_RESOURCES_RESULT_TYPE_REF = new TypeReference<>() { - }; - - private static final TypeReference READ_RESOURCE_RESULT_TYPE_REF = new TypeReference<>() { - }; - - private static final TypeReference LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF = new TypeReference<>() { - }; - - /** - * Send a resources/list request. - * @return A Mono that completes with the list of resources result - */ - public Mono listResources() { - return this.listResources(null); - } - - /** - * Send a resources/list request. - * @param cursor the cursor for pagination - * @return A Mono that completes with the list of resources result - */ - public Mono listResources(String cursor) { - return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_LIST, new McpSchema.PaginatedRequest(cursor), - LIST_RESOURCES_RESULT_TYPE_REF); - } - - /** - * Send a resources/read request. - * @param resource the resource to read - * @return A Mono that completes with the resource content - */ - public Mono readResource(McpSchema.Resource resource) { - return this.readResource(new McpSchema.ReadResourceRequest(resource.uri())); - } - - /** - * Send a resources/read request. - * @param readResourceRequest the read resource request - * @return A Mono that completes with the resource content - */ - public Mono readResource(McpSchema.ReadResourceRequest readResourceRequest) { - return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_READ, readResourceRequest, - READ_RESOURCE_RESULT_TYPE_REF); - } - - /** - * Resource templates allow servers to expose parameterized resources using URI - * templates. Arguments may be auto-completed through the completion API. - * - * Request a list of resource templates the server has. - * @return A Mono that completes with the list of resource templates result - */ - public Mono listResourceTemplates() { - return this.listResourceTemplates(null); - } - - /** - * Resource templates allow servers to expose parameterized resources using URI - * templates. Arguments may be auto-completed through the completion API. - * - * Request a list of resource templates the server has. - * @param cursor the cursor for pagination - * @return A Mono that completes with the list of resource templates result - */ - public Mono listResourceTemplates(String cursor) { - return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, - new McpSchema.PaginatedRequest(cursor), LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF); - } - - /** - * List Changed Notification. When the list of available resources changes, servers - * that declared the listChanged capability SHOULD send a notification. - * @return A Mono that completes when the notification is sent - */ - public Mono sendResourcesListChanged() { - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED); - } - - /** - * Subscriptions. The protocol supports optional subscriptions to resource changes. - * Clients can subscribe to specific resources and receive notifications when they - * change. - * - * Send a resources/subscribe request. - * @param subscribeRequest the subscribe request contains the uri of the resource to - * subscribe to - * @return A Mono that completes when the subscription is complete - */ - public Mono subscribeResource(McpSchema.SubscribeRequest subscribeRequest) { - return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_SUBSCRIBE, subscribeRequest, VOID_TYPE_REFERENCE); - } - - /** - * Send a resources/unsubscribe request. - * @param unsubscribeRequest the unsubscribe request contains the uri of the resource - * to unsubscribe from - * @return A Mono that completes when the unsubscription is complete - */ - public Mono unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) { - return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_UNSUBSCRIBE, unsubscribeRequest, - VOID_TYPE_REFERENCE); - } - - @Deprecated - private NotificationHandler resourcesChangeNotificationHandler( - List>> resourcesChangeConsumers) { - - return params -> listResources().flatMap(listResourcesResult -> Mono.fromRunnable(() -> { - for (Consumer> resourceChangeConsumer : resourcesChangeConsumers) { - resourceChangeConsumer.accept(listResourcesResult.resources()); - } - }).subscribeOn(Schedulers.boundedElastic())).onErrorResume(error -> { - logger.error("Error handling resources list change notification", error); - return Mono.empty(); - }).then(); - } - - private NotificationHandler asyncResourcesChangeNotificationHandler( - List, Mono>> resourcesChangeConsumers) { - return params -> listResources().flatMap(listResourcesResult -> Flux.fromIterable(resourcesChangeConsumers) - .flatMap(consumer -> consumer.apply(listResourcesResult.resources())) - .onErrorResume(error -> { - logger.error("Error handling resources list change notification", error); - return Mono.empty(); - }) - .then()); - } - - // -------------------------- - // Prompts - // -------------------------- - private static final TypeReference LIST_PROMPTS_RESULT_TYPE_REF = new TypeReference<>() { - }; - - private static final TypeReference GET_PROMPT_RESULT_TYPE_REF = new TypeReference<>() { - }; - - /** - * List all available prompts. - * @return A Mono that completes with the list of prompts result - */ - public Mono listPrompts() { - return this.listPrompts(null); - } - - /** - * List all available prompts. - * @param cursor the cursor for pagination - * @return A Mono that completes with the list of prompts result - */ - public Mono listPrompts(String cursor) { - return this.mcpSession.sendRequest(McpSchema.METHOD_PROMPT_LIST, new PaginatedRequest(cursor), - LIST_PROMPTS_RESULT_TYPE_REF); - } - - /** - * Get a prompt by its id. - * @param getPromptRequest the get prompt request - * @return A Mono that completes with the get prompt result - */ - public Mono getPrompt(GetPromptRequest getPromptRequest) { - return this.mcpSession.sendRequest(McpSchema.METHOD_PROMPT_GET, getPromptRequest, GET_PROMPT_RESULT_TYPE_REF); - } - - /** - * (Server) An optional notification from the server to the client, informing it that - * the list of prompts it offers has changed. This may be issued by servers without - * any previous subscription from the client. - * @return A Mono that completes when the notification is sent - */ - public Mono promptListChangedNotification() { - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED); - } - - @Deprecated - private NotificationHandler promptsChangeNotificationHandler( - List>> promptsChangeConsumers) { - - return params -> {// @formatter:off - return listPrompts().flatMap(listPromptsResult -> Mono.fromRunnable(() -> { - for (Consumer> promptChangeConsumer : promptsChangeConsumers) { - promptChangeConsumer.accept(listPromptsResult.prompts()); - } - }).subscribeOn(Schedulers.boundedElastic())).onErrorResume(error -> { - logger.error("Error handling prompts list change notification", error); - return Mono.empty(); - }).then(); // Convert to Mono - }; // @formatter:on - } - - private NotificationHandler asyncPromptsChangeNotificationHandler( - List, Mono>> promptsChangeConsumers) { - return params -> listPrompts().flatMap(listPromptsResult -> Flux.fromIterable(promptsChangeConsumers) - .flatMap(consumer -> consumer.apply(listPromptsResult.prompts())) - .onErrorResume(error -> { - logger.error("Error handling prompts list change notification", error); - return Mono.empty(); - }) - .then()); - } - - // -------------------------- - // Logging - // -------------------------- - /** - * Create a notification handler for logging notifications from the server. This - * handler automatically distributes logging messages to all registered consumers. - * @param loggingConsumers List of consumers that will be notified when a logging - * message is received. Each consumer receives the logging message notification. - * @return A NotificationHandler that processes log notifications by distributing the - * message to all registered consumers - */ - private NotificationHandler loggingNotificationHandler( - List> loggingConsumers) { - - return params -> { - - McpSchema.LoggingMessageNotification loggingMessageNotification = transport.unmarshalFrom(params, - new TypeReference() { - }); - - return Mono.fromRunnable(() -> { - for (Consumer loggingConsumer : loggingConsumers) { - loggingConsumer.accept(loggingMessageNotification); - } - }).subscribeOn(Schedulers.boundedElastic()).then(); - - }; - } - - /** - * Create a notification handler for logging notifications from the server. This - * handler automatically distributes logging messages to all registered consumers. - * @param loggingConsumers List of consumers that will be notified when a logging - * message is received. Each consumer receives the logging message notification. - * @return A NotificationHandler that processes log notifications by distributing the - * message to all registered consumers - */ - private NotificationHandler asyncLoggingNotificationHandler( - List>> loggingConsumers) { - - return params -> { - McpSchema.LoggingMessageNotification loggingMessageNotification = transport.unmarshalFrom(params, - new TypeReference() { - }); - - return Flux.fromIterable(loggingConsumers) - .flatMap(consumer -> consumer.apply(loggingMessageNotification)) - .then(); - }; - } - - /** - * Client can set the minimum logging level it wants to receive from the server. - * @param loggingLevel the min logging level - */ - public Mono setLoggingLevel(LoggingLevel loggingLevel) { - Assert.notNull(loggingLevel, "Logging level must not be null"); - - String levelName = this.transport.unmarshalFrom(loggingLevel, new TypeReference() { - }); - - Map params = Map.of("level", levelName); - - return this.mcpSession.sendNotification(McpSchema.METHOD_LOGGING_SET_LEVEL, params); - } - - /** - * This method is package-private and used for test only. Should not be called by user - * code. - * @param protocolVersions the Client supported protocol versions. - */ - void setProtocolVersions(List protocolVersions) { - this.protocolVersions = protocolVersions; - } - -} diff --git a/mcp/src/main/java/org/springframework/ai/mcp/server/McpAsyncServer.java b/mcp/src/main/java/org/springframework/ai/mcp/server/McpAsyncServer.java deleted file mode 100644 index 00971b7cc..000000000 --- a/mcp/src/main/java/org/springframework/ai/mcp/server/McpAsyncServer.java +++ /dev/null @@ -1,928 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mcp.server; - -import java.time.Duration; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.function.Consumer; -import java.util.function.Function; - -import com.fasterxml.jackson.core.type.TypeReference; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; - -import org.springframework.ai.mcp.server.McpServer.PromptRegistration; -import org.springframework.ai.mcp.server.McpServer.ResourceRegistration; -import org.springframework.ai.mcp.server.McpServer.ToolRegistration; -import org.springframework.ai.mcp.spec.DefaultMcpSession; -import org.springframework.ai.mcp.spec.DefaultMcpSession.NotificationHandler; -import org.springframework.ai.mcp.spec.McpError; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.CallToolResult; -import org.springframework.ai.mcp.spec.McpSchema.ClientCapabilities; -import org.springframework.ai.mcp.spec.McpSchema.LoggingLevel; -import org.springframework.ai.mcp.spec.McpSchema.LoggingMessageNotification; -import org.springframework.ai.mcp.spec.McpSchema.Tool; -import org.springframework.ai.mcp.spec.ServerMcpTransport; -import org.springframework.ai.mcp.util.Utils; - -/** - * The Model Context Protocol (MCP) server implementation that provides asynchronous - * communication using Project Reactor's Mono and Flux types. - * - *

    - * This server implements the MCP specification, enabling AI models to expose tools, - * resources, and prompts through a standardized interface. Key features include: - *

      - *
    • Asynchronous communication using reactive programming patterns - *
    • Dynamic tool registration and management - *
    • Resource handling with URI-based addressing - *
    • Prompt template management - *
    • Real-time client notifications for state changes - *
    • Structured logging with configurable severity levels - *
    • Support for client-side AI model sampling - *
    - * - *

    - * The server follows a lifecycle: - *

      - *
    1. Initialization - Accepts client connections and negotiates capabilities - *
    2. Normal Operation - Handles client requests and sends notifications - *
    3. Graceful Shutdown - Ensures clean connection termination - *
    - * - *

    - * This implementation uses Project Reactor for non-blocking operations, making it - * suitable for high-throughput scenarios and reactive applications. All operations return - * Mono or Flux types that can be composed into reactive pipelines. - * - *

    - * The server supports runtime modification of its capabilities through methods like - * {@link #addTool}, {@link #addResource}, and {@link #addPrompt}, automatically notifying - * connected clients of changes when configured to do so. - * - * @author Christian Tzolov - * @author Dariusz Jędrzejczyk - * @see McpServer - * @see McpSchema - * @see DefaultMcpSession - */ -public class McpAsyncServer { - - private static final Logger logger = LoggerFactory.getLogger(McpAsyncServer.class); - - /** - * The MCP session implementation that manages bidirectional JSON-RPC communication - * between clients and servers. - */ - private final DefaultMcpSession mcpSession; - - private final ServerMcpTransport transport; - - private final McpSchema.ServerCapabilities serverCapabilities; - - private final McpSchema.Implementation serverInfo; - - private McpSchema.ClientCapabilities clientCapabilities; - - private McpSchema.Implementation clientInfo; - - /** - * Thread-safe list of tool handlers that can be modified at runtime. - */ - private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); - - private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); - - private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); - - private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); - - private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; - - /** - * Supported protocol versions. - */ - private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); - - /** - * Create a new McpAsyncServer with the given transport and capabilities. - * @param mcpTransport The transport layer implementation for MCP communication. - * @param features The MCP server supported features. - */ - McpAsyncServer(ServerMcpTransport mcpTransport, McpServerFeatures.Async features) { - - this.serverInfo = features.serverInfo(); - this.serverCapabilities = features.serverCapabilities(); - this.tools.addAll(features.tools()); - this.resources.putAll(features.resources()); - this.resourceTemplates.addAll(features.resourceTemplates()); - this.prompts.putAll(features.prompts()); - - Map> requestHandlers = new HashMap<>(); - - // Initialize request handlers for standard MCP methods - requestHandlers.put(McpSchema.METHOD_INITIALIZE, asyncInitializeRequestHandler()); - - // Ping MUST respond with an empty data, but not NULL response. - requestHandlers.put(McpSchema.METHOD_PING, (params) -> Mono.just("")); - - // Add tools API handlers if the tool capability is enabled - if (this.serverCapabilities.tools() != null) { - requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); - } - - // Add resources API handlers if provided - if (!Utils.isEmpty(this.resources)) { - requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); - } - - // Add resource templates API handlers if provided. - if (!Utils.isEmpty(this.resourceTemplates)) { - requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); - } - - // Add prompts API handlers if provider exists - if (!Utils.isEmpty(this.prompts)) { - requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); - } - - // Add logging API handlers if the logging capability is enabled - if (this.serverCapabilities.logging() != null) { - requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); - } - - Map notificationHandlers = new HashMap<>(); - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (params) -> Mono.empty()); - - List, Mono>> rootsChangeConsumers = features.rootsChangeConsumers(); - - if (Utils.isEmpty(rootsChangeConsumers)) { - rootsChangeConsumers = List.of((roots) -> Mono.fromRunnable(() -> logger - .warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); - } - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, - asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - - this.transport = mcpTransport; - this.mcpSession = new DefaultMcpSession(Duration.ofSeconds(10), mcpTransport, requestHandlers, - notificationHandlers); - } - - /** - * Create a new McpAsyncServer with the given transport and capabilities. - * @param mcpTransport The transport layer implementation for MCP communication - * @param serverInfo The server implementation details - * @param serverCapabilities The server capabilities - * @param tools The list of tool registrations - * @param resources The map of resource registrations - * @param resourceTemplates The list of resource templates - * @param prompts The map of prompt registrations - * @param rootsChangeConsumers The list of consumers that will be notified when the - * roots list changes - * @deprecated Use {@link McpServer#sync(ServerMcpTransport)} or - * {@link McpServer#async(ServerMcpTransport)} to create a new server instance. - */ - @Deprecated - public McpAsyncServer(ServerMcpTransport mcpTransport, McpSchema.Implementation serverInfo, - McpSchema.ServerCapabilities serverCapabilities, List tools, - Map resources, List resourceTemplates, - Map prompts, List>> rootsChangeConsumers) { - - this.protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); - this.serverInfo = serverInfo; - if (!Utils.isEmpty(tools)) { - this.tools.addAll(McpServer.mapDeprecatedTools(tools)); - } - if (!Utils.isEmpty(resources)) { - this.resources.putAll(McpServer.mapDeprecatedResources(resources)); - } - if (!Utils.isEmpty(resourceTemplates)) { - this.resourceTemplates.addAll(resourceTemplates); - } - if (!Utils.isEmpty(prompts)) { - this.prompts.putAll(McpServer.mapDeprecatedPrompts(prompts)); - } - - this.serverCapabilities = (serverCapabilities != null) ? serverCapabilities : new McpSchema.ServerCapabilities( - null, // experimental - new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable logging - // by default - !Utils.isEmpty(this.prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, - !Utils.isEmpty(this.resources) ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) - : null, - !Utils.isEmpty(this.tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null); - - Map> requestHandlers = new HashMap<>(); - - // Initialize request handlers for standard MCP methods - requestHandlers.put(McpSchema.METHOD_INITIALIZE, initializeRequestHandler()); - - // Ping MUST respond with an empty data, but not NULL response. - requestHandlers.put(McpSchema.METHOD_PING, (params) -> Mono.just("")); - - // Add tools API handlers if the tool capability is enabled - if (this.serverCapabilities.tools() != null) { - requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); - } - - // Add resources API handlers if provided - if (!Utils.isEmpty(this.resources)) { - requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); - } - - // Add resource templates API handlers if provided. - if (!Utils.isEmpty(this.resourceTemplates)) { - requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); - } - - // Add prompts API handlers if provider exists - if (!Utils.isEmpty(this.prompts)) { - requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); - } - - // Add logging API handlers if the logging capability is enabled - if (this.serverCapabilities.logging() != null) { - requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); - } - - Map notificationHandlers = new HashMap<>(); - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (params) -> Mono.empty()); - - if (Utils.isEmpty(rootsChangeConsumers)) { - rootsChangeConsumers = List.of((roots) -> logger - .warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots)); - } - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, - rootsListChnagedNotificationHandler(rootsChangeConsumers)); - - this.transport = mcpTransport; - this.mcpSession = new DefaultMcpSession(Duration.ofSeconds(10), mcpTransport, requestHandlers, - notificationHandlers); - } - - // --------------------------------------- - // Lifecycle Management - // --------------------------------------- - private DefaultMcpSession.RequestHandler asyncInitializeRequestHandler() { - return params -> { - McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(params, - new TypeReference() { - }); - this.clientCapabilities = initializeRequest.capabilities(); - this.clientInfo = initializeRequest.clientInfo(); - logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", - initializeRequest.protocolVersion(), initializeRequest.capabilities(), - initializeRequest.clientInfo()); - - String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); - - if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { - serverProtocolVersion = initializeRequest.protocolVersion(); - } - else { - logger.warn( - "Client requested unsupported protocol version: {}, so the server will sugggest the {} version instead", - initializeRequest.protocolVersion(), serverProtocolVersion); - } - - return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, - this.serverInfo, null)); - }; - } - - @Deprecated - private DefaultMcpSession.RequestHandler initializeRequestHandler() { - return params -> { - McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(params, - new TypeReference() { - }); - - this.clientCapabilities = initializeRequest.capabilities(); - this.clientInfo = initializeRequest.clientInfo(); - - logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", - initializeRequest.protocolVersion(), initializeRequest.capabilities(), - initializeRequest.clientInfo()); - - if (!McpSchema.LATEST_PROTOCOL_VERSION.equals(initializeRequest.protocolVersion())) { - return Mono.error(new McpError( - "Unsupported protocol version from client: " + initializeRequest.protocolVersion())) - .publishOn(Schedulers.boundedElastic()); - } - - return Mono - .just(new McpSchema.InitializeResult(McpSchema.LATEST_PROTOCOL_VERSION, this.serverCapabilities, - this.serverInfo, null)) - .publishOn(Schedulers.boundedElastic()); - }; - } - - /** - * Get the server capabilities that define the supported features and functionality. - * @return The server capabilities - */ - public McpSchema.ServerCapabilities getServerCapabilities() { - return this.serverCapabilities; - } - - /** - * Get the server implementation information. - * @return The server implementation details - */ - public McpSchema.Implementation getServerInfo() { - return this.serverInfo; - } - - /** - * Get the client capabilities that define the supported features and functionality. - * @return The client capabilities - */ - public ClientCapabilities getClientCapabilities() { - return this.clientCapabilities; - } - - /** - * Get the client implementation information. - * @return The client implementation details - */ - public McpSchema.Implementation getClientInfo() { - return this.clientInfo; - } - - /** - * Gracefully closes the server, allowing any in-progress operations to complete. - * @return A Mono that completes when the server has been closed - */ - public Mono closeGracefully() { - return this.mcpSession.closeGracefully(); - } - - /** - * Close the server immediately. - */ - public void close() { - this.mcpSession.close(); - } - - private static final TypeReference LIST_ROOTS_RESULT_TYPE_REF = new TypeReference<>() { - }; - - /** - * Retrieves the list of all roots provided by the client. - * @return A Mono that emits the list of roots result. - */ - public Mono listRoots() { - return this.listRoots(null); - } - - /** - * Retrieves a paginated list of roots provided by the server. - * @param cursor Optional pagination cursor from a previous list request - * @return A Mono that emits the list of roots result containing - */ - public Mono listRoots(String cursor) { - return this.mcpSession.sendRequest(McpSchema.METHOD_ROOTS_LIST, new McpSchema.PaginatedRequest(cursor), - LIST_ROOTS_RESULT_TYPE_REF); - } - - @Deprecated - private NotificationHandler rootsListChnagedNotificationHandler( - List>> rootsChangeConsumers) { - - return params -> { - return listRoots().flatMap(listRootsResult -> Mono.fromRunnable(() -> { - rootsChangeConsumers.stream().forEach(consumer -> consumer.accept(listRootsResult.roots())); - }).subscribeOn(Schedulers.boundedElastic())).onErrorResume(error -> { - logger.error("Error handling roots list change notification", error); - return Mono.empty(); - }).then(); - }; - } - - private NotificationHandler asyncRootsListChangedNotificationHandler( - List, Mono>> rootsChangeConsumers) { - return params -> listRoots().flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) - .flatMap(consumer -> consumer.apply(listRootsResult.roots())) - .onErrorResume(error -> { - logger.error("Error handling roots list change notification", error); - return Mono.empty(); - }) - .then()); - } - - // --------------------------------------- - // Tool Management - // --------------------------------------- - - /** - * Add a new tool registration at runtime. - * @param toolRegistration The tool registration to add - * @return Mono that completes when clients have been notified of the change - */ - public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistration) { - if (toolRegistration == null) { - return Mono.error(new McpError("Tool registration must not be null")); - } - if (toolRegistration.tool() == null) { - return Mono.error(new McpError("Tool must not be null")); - } - if (toolRegistration.call() == null) { - return Mono.error(new McpError("Tool call handler must not be null")); - } - if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); - } - - return Mono.defer(() -> { - // Check for duplicate tool names - if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolRegistration.tool().name()))) { - return Mono - .error(new McpError("Tool with name '" + toolRegistration.tool().name() + "' already exists")); - } - - this.tools.add(toolRegistration); - logger.info("Added tool handler: {}", toolRegistration.tool().name()); - - if (this.serverCapabilities.tools().listChanged()) { - return notifyToolsListChanged(); - } - return Mono.empty(); - }); - } - - /** - * Add a new tool registration at runtime. - * @param toolRegistration The tool registration to add - * @return Mono that completes when clients have been notified of the change - * @deprecated Use {@link #addTool(McpServerFeatures.AsyncToolRegistration)}. - */ - @Deprecated - public Mono addTool(ToolRegistration toolRegistration) { - if (toolRegistration == null) { - return Mono.error(new McpError("Tool registration must not be null")); - } - if (toolRegistration.tool() == null) { - return Mono.error(new McpError("Tool must not be null")); - } - if (toolRegistration.call() == null) { - return Mono.error(new McpError("Tool call handler must not be null")); - } - if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); - } - - // Check for duplicate tool names - if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolRegistration.tool().name()))) { - return Mono.error(new McpError("Tool with name '" + toolRegistration.tool().name() + "' already exists")); - } - - this.tools.add(McpServer.mapDeprecatedTool(toolRegistration)); - logger.info("Added tool handler: {}", toolRegistration.tool().name()); - if (this.serverCapabilities.tools().listChanged()) { - return notifyToolsListChanged(); - } - return Mono.empty(); - } - - /** - * Remove a tool handler at runtime. - * @param toolName The name of the tool handler to remove - * @return Mono that completes when clients have been notified of the change - */ - public Mono removeTool(String toolName) { - if (toolName == null) { - return Mono.error(new McpError("Tool name must not be null")); - } - if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); - } - - return Mono.defer(() -> { - boolean removed = this.tools.removeIf(toolRegistration -> toolRegistration.tool().name().equals(toolName)); - if (removed) { - logger.info("Removed tool handler: {}", toolName); - if (this.serverCapabilities.tools().listChanged()) { - return notifyToolsListChanged(); - } - return Mono.empty(); - } - return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); - }); - } - - /** - * Notifies clients that the list of available tools has changed. - * @return A Mono that completes when all clients have been notified - */ - public Mono notifyToolsListChanged() { - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); - } - - private DefaultMcpSession.RequestHandler toolsListRequestHandler() { - return params -> { - List tools = this.tools.stream().map(McpServerFeatures.AsyncToolRegistration::tool).toList(); - - return Mono.just(new McpSchema.ListToolsResult(tools, null)); - }; - } - - private DefaultMcpSession.RequestHandler toolsCallRequestHandler() { - return params -> { - McpSchema.CallToolRequest callToolRequest = transport.unmarshalFrom(params, - new TypeReference() { - }); - - Optional toolRegistration = this.tools.stream() - .filter(tr -> callToolRequest.name().equals(tr.tool().name())) - .findAny(); - - if (toolRegistration.isEmpty()) { - return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); - } - - return toolRegistration.map(tool -> tool.call().apply(callToolRequest.arguments())) - .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); - }; - } - - // --------------------------------------- - // Resource Management - // --------------------------------------- - - /** - * Add a new resource handler at runtime. - * @param resourceHandler The resource handler to add - * @return Mono that completes when clients have been notified of the change - */ - public Mono addResource(McpServerFeatures.AsyncResourceRegistration resourceHandler) { - if (resourceHandler == null || resourceHandler.resource() == null) { - return Mono.error(new McpError("Resource must not be null")); - } - - if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); - } - - return Mono.defer(() -> { - if (this.resources.putIfAbsent(resourceHandler.resource().uri(), resourceHandler) != null) { - return Mono - .error(new McpError("Resource with URI '" + resourceHandler.resource().uri() + "' already exists")); - } - logger.info("Added resource handler: {}", resourceHandler.resource().uri()); - if (this.serverCapabilities.resources().listChanged()) { - return notifyResourcesListChanged(); - } - return Mono.empty(); - }); - } - - /** - * Add a new resource handler at runtime. - * @param resourceHandler The resource handler to add - * @return Mono that completes when clients have been notified of the change - * @deprecated Use {@link #addResource(McpServerFeatures.AsyncResourceRegistration)}. - */ - @Deprecated - public Mono addResource(ResourceRegistration resourceHandler) { - if (resourceHandler == null || resourceHandler.resource() == null) { - return Mono.error(new McpError("Resource must not be null")); - } - - if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); - } - - if (this.resources.containsKey(resourceHandler.resource().uri())) { - return Mono - .error(new McpError("Resource with URI '" + resourceHandler.resource().uri() + "' already exists")); - } - - this.resources.put(resourceHandler.resource().uri(), McpServer.mapDeprecatedResource(resourceHandler)); - logger.info("Added resource handler: {}", resourceHandler.resource().uri()); - if (this.serverCapabilities.resources().listChanged()) { - return notifyResourcesListChanged(); - } - return Mono.empty(); - } - - /** - * Remove a resource handler at runtime. - * @param resourceUri The URI of the resource handler to remove - * @return Mono that completes when clients have been notified of the change - */ - public Mono removeResource(String resourceUri) { - if (resourceUri == null) { - return Mono.error(new McpError("Resource URI must not be null")); - } - if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); - } - - return Mono.defer(() -> { - McpServerFeatures.AsyncResourceRegistration removed = this.resources.remove(resourceUri); - if (removed != null) { - logger.info("Removed resource handler: {}", resourceUri); - if (this.serverCapabilities.resources().listChanged()) { - return notifyResourcesListChanged(); - } - return Mono.empty(); - } - return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); - }); - } - - /** - * Notifies clients that the list of available resources has changed. - * @return A Mono that completes when all clients have been notified - */ - public Mono notifyResourcesListChanged() { - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); - } - - private DefaultMcpSession.RequestHandler resourcesListRequestHandler() { - return params -> { - var resourceList = this.resources.values() - .stream() - .map(McpServerFeatures.AsyncResourceRegistration::resource) - .toList(); - return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); - }; - } - - private DefaultMcpSession.RequestHandler resourceTemplateListRequestHandler() { - return params -> Mono.just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null)); - - } - - private DefaultMcpSession.RequestHandler resourcesReadRequestHandler() { - return params -> { - McpSchema.ReadResourceRequest resourceRequest = transport.unmarshalFrom(params, - new TypeReference() { - }); - var resourceUri = resourceRequest.uri(); - McpServerFeatures.AsyncResourceRegistration registration = this.resources.get(resourceUri); - if (registration != null) { - return registration.readHandler().apply(resourceRequest); - } - return Mono.error(new McpError("Resource not found: " + resourceUri)); - }; - } - - // --------------------------------------- - // Prompt Management - // --------------------------------------- - - /** - * Add a new prompt handler at runtime. - * @param promptRegistration The prompt handler to add - * @return Mono that completes when clients have been notified of the change - */ - public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegistration) { - if (promptRegistration == null) { - return Mono.error(new McpError("Prompt registration must not be null")); - } - if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); - } - - return Mono.defer(() -> { - McpServerFeatures.AsyncPromptRegistration registration = this.prompts - .putIfAbsent(promptRegistration.prompt().name(), promptRegistration); - if (registration != null) { - return Mono.error( - new McpError("Prompt with name '" + promptRegistration.prompt().name() + "' already exists")); - } - - logger.info("Added prompt handler: {}", promptRegistration.prompt().name()); - - // Servers that declared the listChanged capability SHOULD send a - // notification, - // when the list of available prompts changes - if (this.serverCapabilities.prompts().listChanged()) { - return notifyPromptsListChanged(); - } - return Mono.empty(); - }); - } - - /** - * Add a new prompt handler at runtime. - * @param promptRegistration The prompt handler to add - * @return Mono that completes when clients have been notified of the change - * @deprecated Use {@link #addPrompt(McpServerFeatures.AsyncPromptRegistration)}. - */ - @Deprecated - public Mono addPrompt(PromptRegistration promptRegistration) { - if (promptRegistration == null) { - return Mono.error(new McpError("Prompt registration must not be null")); - } - if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); - } - - if (this.prompts.containsKey(promptRegistration.prompt().name())) { - return Mono - .error(new McpError("Prompt with name '" + promptRegistration.prompt().name() + "' already exists")); - } - - this.prompts.put(promptRegistration.prompt().name(), McpServer.mapDeprecatedPrompt(promptRegistration)); - - logger.info("Added prompt handler: {}", promptRegistration.prompt().name()); - - // Servers that declared the listChanged capability SHOULD send a notification, - // when the list of available prompts changes - if (this.serverCapabilities.prompts().listChanged()) { - return notifyPromptsListChanged(); - } - return Mono.empty(); - } - - /** - * Remove a prompt handler at runtime. - * @param promptName The name of the prompt handler to remove - * @return Mono that completes when clients have been notified of the change - */ - public Mono removePrompt(String promptName) { - if (promptName == null) { - return Mono.error(new McpError("Prompt name must not be null")); - } - if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); - } - - return Mono.defer(() -> { - McpServerFeatures.AsyncPromptRegistration removed = this.prompts.remove(promptName); - - if (removed != null) { - logger.info("Removed prompt handler: {}", promptName); - // Servers that declared the listChanged capability SHOULD send a - // notification, when the list of available prompts changes - if (this.serverCapabilities.prompts().listChanged()) { - return this.notifyPromptsListChanged(); - } - return Mono.empty(); - } - return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); - }); - } - - /** - * Notifies clients that the list of available prompts has changed. - * @return A Mono that completes when all clients have been notified - */ - public Mono notifyPromptsListChanged() { - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); - } - - private DefaultMcpSession.RequestHandler promptsListRequestHandler() { - return params -> { - // TODO: Implement pagination - // McpSchema.PaginatedRequest request = transport.unmarshalFrom(params, - // new TypeReference() { - // }); - - var promptList = this.prompts.values() - .stream() - .map(McpServerFeatures.AsyncPromptRegistration::prompt) - .toList(); - - return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); - }; - } - - private DefaultMcpSession.RequestHandler promptsGetRequestHandler() { - return params -> { - McpSchema.GetPromptRequest promptRequest = transport.unmarshalFrom(params, - new TypeReference() { - }); - - // Implement prompt retrieval logic here - McpServerFeatures.AsyncPromptRegistration registration = this.prompts.get(promptRequest.name()); - if (registration == null) { - return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); - } - - return registration.promptHandler().apply(promptRequest); - }; - } - - // --------------------------------------- - // Logging Management - // --------------------------------------- - - /** - * Send a logging message notification to all connected clients. Messages below the - * current minimum logging level will be filtered out. - * @param loggingMessageNotification The logging message to send - * @return A Mono that completes when the notification has been sent - */ - public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { - - if (loggingMessageNotification == null) { - return Mono.error(new McpError("Logging message must not be null")); - } - - Map params = this.transport.unmarshalFrom(loggingMessageNotification, - new TypeReference>() { - }); - - if (loggingMessageNotification.level().level() < minLoggingLevel.level()) { - return Mono.empty(); - } - - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_MESSAGE, params); - } - - /** - * Handles requests to set the minimum logging level. Messages below this level will - * not be sent. - * @return A handler that processes logging level change requests - */ - private DefaultMcpSession.RequestHandler setLoggerRequestHandler() { - return params -> { - this.minLoggingLevel = transport.unmarshalFrom(params, new TypeReference() { - }); - - return Mono.empty(); - }; - } - - // --------------------------------------- - // Sampling - // --------------------------------------- - private static final TypeReference CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() { - }; - - /** - * Create a new message using the sampling capabilities of the client. The Model - * Context Protocol (MCP) provides a standardized way for servers to request LLM - * sampling (“completions” or “generations”) from language models via clients. This - * flow allows clients to maintain control over model access, selection, and - * permissions while enabling servers to leverage AI capabilities—with no server API - * keys necessary. Servers can request text or image-based interactions and optionally - * include context from MCP servers in their prompts. - * @param createMessageRequest The request to create a new message - * @return A Mono that completes when the message has been created - * @throws McpError if the client has not been initialized or does not support - * sampling capabilities - * @throws McpError if the client does not support the createMessage method - * @see McpSchema.CreateMessageRequest - * @see McpSchema.CreateMessageResult - * @see Sampling - * Specification - */ - public Mono createMessage(McpSchema.CreateMessageRequest createMessageRequest) { - - if (this.clientCapabilities == null) { - return Mono.error(new McpError("Client must be initialized. Call the initialize method first!")); - } - if (this.clientCapabilities.sampling() == null) { - return Mono.error(new McpError("Client must be configured with sampling capabilities")); - } - return this.mcpSession.sendRequest(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, createMessageRequest, - CREATE_MESSAGE_RESULT_TYPE_REF); - } - - /** - * This method is package-private and used for test only. Should not be called by user - * code. - * @param protocolVersions the Client supported protocol versions. - */ - void setProtocolVersions(List protocolVersions) { - this.protocolVersions = protocolVersions; - } - -} diff --git a/mcp/src/main/java/org/springframework/ai/mcp/server/McpServerFeatures.java b/mcp/src/main/java/org/springframework/ai/mcp/server/McpServerFeatures.java deleted file mode 100644 index 2dd4e66ee..000000000 --- a/mcp/src/main/java/org/springframework/ai/mcp/server/McpServerFeatures.java +++ /dev/null @@ -1,401 +0,0 @@ -package org.springframework.ai.mcp.server; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.Consumer; -import java.util.function.Function; - -import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; - -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.util.Assert; -import org.springframework.ai.mcp.util.Utils; - -/** - * MCP server features specification that a particular server can choose to support. - * - * @author Dariusz Jędrzejczyk - */ -public class McpServerFeatures { - - /** - * Asynchronous server features specification. - * - * @param serverInfo The server implementation details - * @param serverCapabilities The server capabilities - * @param tools The list of tool registrations - * @param resources The map of resource registrations - * @param resourceTemplates The list of resource templates - * @param prompts The map of prompt registrations - * @param rootsChangeConsumers The list of consumers that will be notified when the - * roots list changes - */ - record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, - List tools, Map resources, - List resourceTemplates, - Map prompts, - List, Mono>> rootsChangeConsumers) { - - /** - * Create an instance and validate the arguments. - * @param serverInfo The server implementation details - * @param serverCapabilities The server capabilities - * @param tools The list of tool registrations - * @param resources The map of resource registrations - * @param resourceTemplates The list of resource templates - * @param prompts The map of prompt registrations - * @param rootsChangeConsumers The list of consumers that will be notified when - * the roots list changes - */ - Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, - List tools, Map resources, - List resourceTemplates, - Map prompts, - List, Mono>> rootsChangeConsumers) { - - Assert.notNull(serverInfo, "Server info must not be null"); - - this.serverInfo = serverInfo; - this.serverCapabilities = (serverCapabilities != null) ? serverCapabilities - : new McpSchema.ServerCapabilities(null, // experimental - new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable - // logging - // by - // default - !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, - !Utils.isEmpty(resources) - ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, - !Utils.isEmpty(tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null); - - this.tools = (tools != null) ? tools : List.of(); - this.resources = (resources != null) ? resources : Map.of(); - this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : List.of(); - this.prompts = (prompts != null) ? prompts : Map.of(); - this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : List.of(); - } - - /** - * Convert a synchronous specification into an asynchronous one and provide - * blocking code offloading to prevent accidental blocking of the non-blocking - * transport. - * @param syncSpec a potentially blocking, synchronous specification. - * @return a specification which is protected from blocking calls specified by the - * user. - */ - static Async fromSync(Sync syncSpec) { - List tools = new ArrayList<>(); - for (var tool : syncSpec.tools()) { - tools.add(AsyncToolRegistration.fromSync(tool)); - } - - Map resources = new HashMap<>(); - syncSpec.resources().forEach((key, resource) -> { - resources.put(key, AsyncResourceRegistration.fromSync(resource)); - }); - - Map prompts = new HashMap<>(); - syncSpec.prompts().forEach((key, prompt) -> { - prompts.put(key, AsyncPromptRegistration.fromSync(prompt)); - }); - - List, Mono>> rootChangeConsumers = new ArrayList<>(); - - for (var rootChangeConsumer : syncSpec.rootsChangeConsumers()) { - rootChangeConsumers.add(list -> Mono.fromRunnable(() -> rootChangeConsumer.accept(list)) - .subscribeOn(Schedulers.boundedElastic())); - } - - return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, resources, - syncSpec.resourceTemplates(), prompts, rootChangeConsumers); - } - } - - /** - * Synchronous server features specification. - * - * @param serverInfo The server implementation details - * @param serverCapabilities The server capabilities - * @param tools The list of tool registrations - * @param resources The map of resource registrations - * @param resourceTemplates The list of resource templates - * @param prompts The map of prompt registrations - * @param rootsChangeConsumers The list of consumers that will be notified when the - * roots list changes - */ - record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, - List tools, - Map resources, - List resourceTemplates, - Map prompts, - List>> rootsChangeConsumers) { - - /** - * Create an instance and validate the arguments. - * @param serverInfo The server implementation details - * @param serverCapabilities The server capabilities - * @param tools The list of tool registrations - * @param resources The map of resource registrations - * @param resourceTemplates The list of resource templates - * @param prompts The map of prompt registrations - * @param rootsChangeConsumers The list of consumers that will be notified when - * the roots list changes - */ - Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, - List tools, - Map resources, - List resourceTemplates, - Map prompts, - List>> rootsChangeConsumers) { - - Assert.notNull(serverInfo, "Server info must not be null"); - - this.serverInfo = serverInfo; - this.serverCapabilities = (serverCapabilities != null) ? serverCapabilities - : new McpSchema.ServerCapabilities(null, // experimental - new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable - // logging - // by - // default - !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, - !Utils.isEmpty(resources) - ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, - !Utils.isEmpty(tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null); - - this.tools = (tools != null) ? tools : new ArrayList<>(); - this.resources = (resources != null) ? resources : new HashMap<>(); - this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : new ArrayList<>(); - this.prompts = (prompts != null) ? prompts : new HashMap<>(); - this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : new ArrayList<>(); - } - - } - - /** - * Registration of a tool with its asynchronous handler function. Tools are the - * primary way for MCP servers to expose functionality to AI models. Each tool - * represents a specific capability, such as: - *

      - *
    • Performing calculations - *
    • Accessing external APIs - *
    • Querying databases - *
    • Manipulating files - *
    • Executing system commands - *
    - * - *

    - * Example tool registration:

    {@code
    -	 * new McpServerFeatures.AsyncToolRegistration(
    -	 *     new Tool(
    -	 *         "calculator",
    -	 *         "Performs mathematical calculations",
    -	 *         new JsonSchemaObject()
    -	 *             .required("expression")
    -	 *             .property("expression", JsonSchemaType.STRING)
    -	 *     ),
    -	 *     args -> {
    -	 *         String expr = (String) args.get("expression");
    -	 *         return Mono.just(new CallToolResult("Result: " + evaluate(expr)));
    -	 *     }
    -	 * )
    -	 * }
    - * - * @param tool The tool definition including name, description, and parameter schema - * @param call The function that implements the tool's logic, receiving arguments and - * returning results - */ - public record AsyncToolRegistration(McpSchema.Tool tool, - Function, Mono> call) { - - static AsyncToolRegistration fromSync(SyncToolRegistration tool) { - // FIXME: This is temporary, proper validation should be implemented - if (tool == null) { - return null; - } - return new AsyncToolRegistration(tool.tool(), - map -> Mono.fromCallable(() -> tool.call().apply(map)).subscribeOn(Schedulers.boundedElastic())); - } - } - - /** - * Registration of a resource with its asynchronous handler function. Resources - * provide context to AI models by exposing data such as: - *
      - *
    • File contents - *
    • Database records - *
    • API responses - *
    • System information - *
    • Application state - *
    - * - *

    - * Example resource registration:

    {@code
    -	 * new McpServerFeatures.AsyncResourceRegistration(
    -	 *     new Resource("docs", "Documentation files", "text/markdown"),
    -	 *     request -> {
    -	 *         String content = readFile(request.getPath());
    -	 *         return Mono.just(new ReadResourceResult(content));
    -	 *     }
    -	 * )
    -	 * }
    - * - * @param resource The resource definition including name, description, and MIME type - * @param readHandler The function that handles resource read requests - */ - public record AsyncResourceRegistration(McpSchema.Resource resource, - Function> readHandler) { - - static AsyncResourceRegistration fromSync(SyncResourceRegistration resource) { - // FIXME: This is temporary, proper validation should be implemented - if (resource == null) { - return null; - } - return new AsyncResourceRegistration(resource.resource(), - req -> Mono.fromCallable(() -> resource.readHandler().apply(req)) - .subscribeOn(Schedulers.boundedElastic())); - } - } - - /** - * Registration of a prompt template with its asynchronous handler function. Prompts - * provide structured templates for AI model interactions, supporting: - *
      - *
    • Consistent message formatting - *
    • Parameter substitution - *
    • Context injection - *
    • Response formatting - *
    • Instruction templating - *
    - * - *

    - * Example prompt registration:

    {@code
    -	 * new McpServerFeatures.AsyncPromptRegistration(
    -	 *     new Prompt("analyze", "Code analysis template"),
    -	 *     request -> {
    -	 *         String code = request.getArguments().get("code");
    -	 *         return Mono.just(new GetPromptResult(
    -	 *             "Analyze this code:\n\n" + code + "\n\nProvide feedback on:"
    -	 *         ));
    -	 *     }
    -	 * )
    -	 * }
    - * - * @param prompt The prompt definition including name and description - * @param promptHandler The function that processes prompt requests and returns - * formatted templates - */ - public record AsyncPromptRegistration(McpSchema.Prompt prompt, - Function> promptHandler) { - - static AsyncPromptRegistration fromSync(SyncPromptRegistration prompt) { - // FIXME: This is temporary, proper validation should be implemented - if (prompt == null) { - return null; - } - return new AsyncPromptRegistration(prompt.prompt(), - req -> Mono.fromCallable(() -> prompt.promptHandler().apply(req)) - .subscribeOn(Schedulers.boundedElastic())); - } - } - - /** - * Registration of a tool with its synchronous handler function. Tools are the primary - * way for MCP servers to expose functionality to AI models. Each tool represents a - * specific capability, such as: - *
      - *
    • Performing calculations - *
    • Accessing external APIs - *
    • Querying databases - *
    • Manipulating files - *
    • Executing system commands - *
    - * - *

    - * Example tool registration:

    {@code
    -	 * new McpServerFeatures.SyncToolRegistration(
    -	 *     new Tool(
    -	 *         "calculator",
    -	 *         "Performs mathematical calculations",
    -	 *         new JsonSchemaObject()
    -	 *             .required("expression")
    -	 *             .property("expression", JsonSchemaType.STRING)
    -	 *     ),
    -	 *     args -> {
    -	 *         String expr = (String) args.get("expression");
    -	 *         return new CallToolResult("Result: " + evaluate(expr));
    -	 *     }
    -	 * )
    -	 * }
    - * - * @param tool The tool definition including name, description, and parameter schema - * @param call The function that implements the tool's logic, receiving arguments and - * returning results - */ - public record SyncToolRegistration(McpSchema.Tool tool, - Function, McpSchema.CallToolResult> call) { - } - - /** - * Registration of a resource with its synchronous handler function. Resources provide - * context to AI models by exposing data such as: - *
      - *
    • File contents - *
    • Database records - *
    • API responses - *
    • System information - *
    • Application state - *
    - * - *

    - * Example resource registration:

    {@code
    -	 * new McpServerFeatures.SyncResourceRegistration(
    -	 *     new Resource("docs", "Documentation files", "text/markdown"),
    -	 *     request -> {
    -	 *         String content = readFile(request.getPath());
    -	 *         return new ReadResourceResult(content);
    -	 *     }
    -	 * )
    -	 * }
    - * - * @param resource The resource definition including name, description, and MIME type - * @param readHandler The function that handles resource read requests - */ - public record SyncResourceRegistration(McpSchema.Resource resource, - Function readHandler) { - } - - /** - * Registration of a prompt template with its synchronous handler function. Prompts - * provide structured templates for AI model interactions, supporting: - *
      - *
    • Consistent message formatting - *
    • Parameter substitution - *
    • Context injection - *
    • Response formatting - *
    • Instruction templating - *
    - * - *

    - * Example prompt registration:

    {@code
    -	 * new McpServerFeatures.SyncPromptRegistration(
    -	 *     new Prompt("analyze", "Code analysis template"),
    -	 *     request -> {
    -	 *         String code = request.getArguments().get("code");
    -	 *         return new GetPromptResult(
    -	 *             "Analyze this code:\n\n" + code + "\n\nProvide feedback on:"
    -	 *         );
    -	 *     }
    -	 * )
    -	 * }
    - * - * @param prompt The prompt definition including name and description - * @param promptHandler The function that processes prompt requests and returns - * formatted templates - */ - public record SyncPromptRegistration(McpSchema.Prompt prompt, - Function promptHandler) { - } - -} diff --git a/mcp/src/main/java/org/springframework/ai/mcp/spec/ClientMcpTransport.java b/mcp/src/main/java/org/springframework/ai/mcp/spec/ClientMcpTransport.java deleted file mode 100644 index 4b5af3b15..000000000 --- a/mcp/src/main/java/org/springframework/ai/mcp/spec/ClientMcpTransport.java +++ /dev/null @@ -1,25 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -package org.springframework.ai.mcp.spec; - -/** - * Marker interface for the client-side MCP transport. - * - * @author Christian Tzolov - */ -public interface ClientMcpTransport extends McpTransport { - -} diff --git a/mcp/src/main/java/org/springframework/ai/mcp/spec/McpError.java b/mcp/src/main/java/org/springframework/ai/mcp/spec/McpError.java deleted file mode 100644 index 94170ff02..000000000 --- a/mcp/src/main/java/org/springframework/ai/mcp/spec/McpError.java +++ /dev/null @@ -1,37 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -package org.springframework.ai.mcp.spec; - -import org.springframework.ai.mcp.spec.McpSchema.JSONRPCResponse.JSONRPCError; - -public class McpError extends RuntimeException { - - private JSONRPCError jsonRpcError; - - public McpError(JSONRPCError jsonRpcError) { - super(jsonRpcError.message()); - this.jsonRpcError = jsonRpcError; - } - - public McpError(Object error) { - super(error.toString()); - } - - public JSONRPCError getJsonRpcError() { - return jsonRpcError; - } - -} \ No newline at end of file diff --git a/mcp/src/main/java/org/springframework/ai/mcp/spec/ServerMcpTransport.java b/mcp/src/main/java/org/springframework/ai/mcp/spec/ServerMcpTransport.java deleted file mode 100644 index bb3d829d8..000000000 --- a/mcp/src/main/java/org/springframework/ai/mcp/spec/ServerMcpTransport.java +++ /dev/null @@ -1,25 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -package org.springframework.ai.mcp.spec; - -/** - * Marker interface for the server-side MCP transport. - * - * @author Christian Tzolov - */ -public interface ServerMcpTransport extends McpTransport { - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpTransport.java b/mcp/src/test/java/io/modelcontextprotocol/MockMcpTransport.java new file mode 100644 index 000000000..12f30d12f --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/MockMcpTransport.java @@ -0,0 +1,97 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.BiConsumer; +import java.util.function.Function; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; + +/** + * A mock implementation of the {@link McpClientTransport} and {@link ServerMcpTransport} + * interfaces. + */ +public class MockMcpTransport implements McpClientTransport, ServerMcpTransport { + + private final Sinks.Many inbound = Sinks.many().unicast().onBackpressureBuffer(); + + private final List sent = new ArrayList<>(); + + private final BiConsumer interceptor; + + public MockMcpTransport() { + this((t, msg) -> { + }); + } + + public MockMcpTransport(BiConsumer interceptor) { + this.interceptor = interceptor; + } + + public void simulateIncomingMessage(McpSchema.JSONRPCMessage message) { + if (inbound.tryEmitNext(message).isFailure()) { + throw new RuntimeException("Failed to process incoming message " + message); + } + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + sent.add(message); + interceptor.accept(this, message); + return Mono.empty(); + } + + public McpSchema.JSONRPCRequest getLastSentMessageAsRequest() { + return (JSONRPCRequest) getLastSentMessage(); + } + + public McpSchema.JSONRPCNotification getLastSentMessageAsNotification() { + return (JSONRPCNotification) getLastSentMessage(); + } + + public McpSchema.JSONRPCMessage getLastSentMessage() { + return !sent.isEmpty() ? sent.get(sent.size() - 1) : null; + } + + private volatile boolean connected = false; + + @Override + public Mono connect(Function, Mono> handler) { + if (connected) { + return Mono.error(new IllegalStateException("Already connected")); + } + connected = true; + return inbound.asFlux() + .flatMap(message -> Mono.just(message).transform(handler)) + .doFinally(signal -> connected = false) + .then(); + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + connected = false; + inbound.tryEmitComplete(); + // Wait for all subscribers to complete + return Mono.empty(); + }); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return new ObjectMapper().convertValue(data, typeRef); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java new file mode 100644 index 000000000..ac7b9e5ec --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -0,0 +1,492 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.time.Duration; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.SubscribeRequest; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpSchema.UnsubscribeRequest; +import io.modelcontextprotocol.spec.McpTransport; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Test suite for the {@link McpAsyncClient} that can be used with different + * {@link McpTransport} implementations. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +// KEEP IN SYNC with the class in mcp-test module +public abstract class AbstractMcpAsyncClientTests { + + private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!"; + + abstract protected McpClientTransport createMcpTransport(); + + protected void onStart() { + } + + protected void onClose() { + } + + protected Duration getRequestTimeout() { + return Duration.ofSeconds(14); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(2); + } + + McpAsyncClient client(McpClientTransport transport) { + return client(transport, Function.identity()); + } + + McpAsyncClient client(McpClientTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); + + assertThatCode(() -> { + McpClient.AsyncSpec builder = McpClient.async(transport) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) + .capabilities(ClientCapabilities.builder().roots(true).build()); + builder = customizer.apply(builder); + client.set(builder.build()); + }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(McpClientTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(McpClientTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); + } + } + + @BeforeEach + void setUp() { + onStart(); + } + + @AfterEach + void tearDown() { + onClose(); + } + + void verifyInitializationTimeout(Function> operation, String action) { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.withVirtualTime(() -> operation.apply(mcpAsyncClient)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before " + action)) + .verify(); + }); + } + + @Test + void testConstructorWithInvalidArguments() { + assertThatThrownBy(() -> McpClient.async(null).build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport must not be null"); + + assertThatThrownBy(() -> McpClient.async(createMcpTransport()).requestTimeout(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Request timeout must not be null"); + } + + @Test + void testListToolsWithoutInitialization() { + verifyInitializationTimeout(client -> client.listTools(null), "listing tools"); + } + + @Test + void testListTools() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(null))) + .consumeNextWith(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); + + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }) + .verifyComplete(); + }); + } + + @Test + void testPingWithoutInitialization() { + verifyInitializationTimeout(client -> client.ping(), "pinging the server"); + } + + @Test + void testPing() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())) + .expectNextCount(1) + .verifyComplete(); + }); + } + + @Test + void testCallToolWithoutInitialization() { + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); + verifyInitializationTimeout(client -> client.callTool(callToolRequest), "calling tools"); + } + + @Test + void testCallTool() { + withClient(createMcpTransport(), mcpAsyncClient -> { + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); + + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest))) + .consumeNextWith(callToolResult -> { + assertThat(callToolResult).isNotNull().satisfies(result -> { + assertThat(result.content()).isNotNull(); + assertThat(result.isError()).isNull(); + }); + }) + .verifyComplete(); + }); + } + + @Test + void testCallToolWithInvalidTool() { + withClient(createMcpTransport(), mcpAsyncClient -> { + CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", + Map.of("message", ECHO_TEST_MESSAGE)); + + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(invalidRequest))) + .consumeErrorWith( + e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Unknown tool: nonexistent_tool")) + .verify(); + }); + } + + @Test + void testListResourcesWithoutInitialization() { + verifyInitializationTimeout(client -> client.listResources(null), "listing resources"); + } + + @Test + void testListResources() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(null))) + .consumeNextWith(resources -> { + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); + }) + .verifyComplete(); + }); + } + + @Test + void testMcpAsyncClientState() { + withClient(createMcpTransport(), mcpAsyncClient -> { + assertThat(mcpAsyncClient).isNotNull(); + }); + } + + @Test + void testListPromptsWithoutInitialization() { + verifyInitializationTimeout(client -> client.listPrompts(null), "listing " + "prompts"); + } + + @Test + void testListPrompts() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(null))) + .consumeNextWith(prompts -> { + assertThat(prompts).isNotNull().satisfies(result -> { + assertThat(result.prompts()).isNotNull(); + + if (!result.prompts().isEmpty()) { + Prompt firstPrompt = result.prompts().get(0); + assertThat(firstPrompt.name()).isNotNull(); + assertThat(firstPrompt.description()).isNotNull(); + } + }); + }) + .verifyComplete(); + }); + } + + @Test + void testGetPromptWithoutInitialization() { + GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); + verifyInitializationTimeout(client -> client.getPrompt(request), "getting " + "prompts"); + } + + @Test + void testGetPrompt() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier + .create(mcpAsyncClient.initialize() + .then(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of())))) + .consumeNextWith(prompt -> { + assertThat(prompt).isNotNull().satisfies(result -> { + assertThat(result.messages()).isNotEmpty(); + assertThat(result.messages()).hasSize(1); + }); + }) + .verifyComplete(); + }); + } + + @Test + void testRootsListChangedWithoutInitialization() { + verifyInitializationTimeout(client -> client.rootsListChangedNotification(), + "sending roots list changed notification"); + } + + @Test + void testRootsListChanged() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification())) + .verifyComplete(); + }); + } + + @Test + void testInitializeWithRootsListProviders() { + withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), + client -> { + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + }); + } + + @Test + void testAddRoot() { + withClient(createMcpTransport(), mcpAsyncClient -> { + Root newRoot = new Root("file:///new/test/path", "new-test-root"); + StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete(); + }); + } + + @Test + void testAddRootWithNullValue() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.addRoot(null)) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Root must not be null")) + .verify(); + }); + } + + @Test + void testRemoveRoot() { + withClient(createMcpTransport(), mcpAsyncClient -> { + Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + StepVerifier.create(mcpAsyncClient.addRoot(root)).verifyComplete(); + + StepVerifier.create(mcpAsyncClient.removeRoot(root.uri())).verifyComplete(); + }); + } + + @Test + void testRemoveNonExistentRoot() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Root with uri 'nonexistent-uri' not found")) + .verify(); + }); + } + + @Test + @Disabled + void testReadResource() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); + StepVerifier.create(mcpAsyncClient.readResource(firstResource)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.contents()).isNotNull(); + }).verifyComplete(); + } + }).verifyComplete(); + }); + } + + @Test + void testListResourceTemplatesWithoutInitialization() { + verifyInitializationTimeout(client -> client.listResourceTemplates(), "listing resource templates"); + } + + @Test + void testListResourceTemplates() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) + .consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }) + .verifyComplete(); + }); + } + + // @Test + void testResourceSubscription() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); + + // Test subscribe + StepVerifier.create(mcpAsyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) + .verifyComplete(); + + // Test unsubscribe + StepVerifier.create(mcpAsyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) + .verifyComplete(); + } + }).verifyComplete(); + }); + } + + @Test + void testNotificationHandlers() { + AtomicBoolean toolsNotificationReceived = new AtomicBoolean(false); + AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); + AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); + + withClient(createMcpTransport(), + builder -> builder + .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) + .resourcesChangeConsumer( + resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) + .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))), + mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()) + .expectNextMatches(Objects::nonNull) + .verifyComplete(); + }); + } + + @Test + void testInitializeWithSamplingCapability() { + ClientCapabilities capabilities = ClientCapabilities.builder().sampling().build(); + CreateMessageResult createMessageResult = CreateMessageResult.builder() + .message("test") + .model("test-model") + .build(); + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).sampling(request -> Mono.just(createMessageResult)), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + }); + } + + @Test + void testInitializeWithAllCapabilities() { + var capabilities = ClientCapabilities.builder() + .experimental(Map.of("feature", "test")) + .roots(true) + .sampling() + .build(); + + Function> samplingHandler = request -> Mono + .just(CreateMessageResult.builder().message("test").model("test-model").build()); + + withClient(createMcpTransport(), builder -> builder.capabilities(capabilities).sampling(samplingHandler), + client -> + + StepVerifier.create(client.initialize()).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.capabilities()).isNotNull(); + }).verifyComplete()); + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + + @Test + void testLoggingLevelsWithoutInitialization() { + verifyInitializationTimeout(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), + "setting logging level"); + } + + @Test + void testLoggingLevels() { + withClient(createMcpTransport(), mcpAsyncClient -> { + Mono testAllLevels = mcpAsyncClient.initialize().then(Mono.defer(() -> { + Mono chain = Mono.empty(); + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + chain = chain.then(mcpAsyncClient.setLoggingLevel(level)); + } + return chain; + })); + + StepVerifier.create(testAllLevels).verifyComplete(); + }); + } + + @Test + void testLoggingConsumer() { + AtomicBoolean logReceived = new AtomicBoolean(false); + + withClient(createMcpTransport(), + builder -> builder.loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + StepVerifier.create(client.closeGracefully()).verifyComplete(); + + }); + + } + + @Test + void testLoggingWithNullNotification() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.setLoggingLevel(null)) + .expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null")) + .verify(); + }); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java new file mode 100644 index 000000000..24c161ebf --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -0,0 +1,450 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.time.Duration; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.ListResourceTemplatesResult; +import io.modelcontextprotocol.spec.McpSchema.ListResourcesResult; +import io.modelcontextprotocol.spec.McpSchema.ListToolsResult; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.SubscribeRequest; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpSchema.UnsubscribeRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for MCP Client Session functionality. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +// KEEP IN SYNC with the class in mcp-test module +public abstract class AbstractMcpSyncClientTests { + + private static final String TEST_MESSAGE = "Hello MCP Spring AI!"; + + abstract protected McpClientTransport createMcpTransport(); + + protected void onStart() { + } + + protected void onClose() { + } + + protected Duration getRequestTimeout() { + return Duration.ofSeconds(14); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(2); + } + + McpSyncClient client(McpClientTransport transport) { + return client(transport, Function.identity()); + } + + McpSyncClient client(McpClientTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); + + assertThatCode(() -> { + McpClient.SyncSpec builder = McpClient.sync(transport) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) + .capabilities(ClientCapabilities.builder().roots(true).build()); + builder = customizer.apply(builder); + client.set(builder.build()); + }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(McpClientTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(McpClientTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + assertThat(client.closeGracefully()).isTrue(); + } + } + + @BeforeEach + void setUp() { + onStart(); + + } + + @AfterEach + void tearDown() { + onClose(); + } + + static final Object DUMMY_RETURN_VALUE = new Object(); + + void verifyNotificationTimesOut(Consumer operation, String action) { + verifyCallTimesOut(client -> { + operation.accept(client); + return DUMMY_RETURN_VALUE; + }, action); + } + + void verifyCallTimesOut(Function blockingOperation, String action) { + withClient(createMcpTransport(), mcpSyncClient -> { + // This scheduler is not replaced by virtual time scheduler + Scheduler customScheduler = Schedulers.newBoundedElastic(1, 1, "actualBoundedElastic"); + + StepVerifier.withVirtualTime(() -> Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient)) + // Offload the blocking call to the real scheduler + .subscribeOn(customScheduler)) + .expectSubscription() + // This works without actually waiting but executes all the + // tasks pending execution on the VirtualTimeScheduler. + // It is possible to execute the blocking code from the operation + // because it is blocked on a dedicated Scheduler and the main + // flow is not blocked and uses the VirtualTimeScheduler. + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before " + action)) + .verify(); + + customScheduler.dispose(); + }); + } + + @Test + void testConstructorWithInvalidArguments() { + assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport must not be null"); + + assertThatThrownBy(() -> McpClient.sync(createMcpTransport()).requestTimeout(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Request timeout must not be null"); + } + + @Test + void testListToolsWithoutInitialization() { + verifyCallTimesOut(client -> client.listTools(null), "listing tools"); + } + + @Test + void testListTools() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListToolsResult tools = mcpSyncClient.listTools(null); + + assertThat(tools).isNotNull().satisfies(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); + + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }); + }); + } + + @Test + void testCallToolsWithoutInitialization() { + verifyCallTimesOut(client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), + "calling tools"); + } + + @Test + void testCallTools() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))); + + assertThat(toolResult).isNotNull().satisfies(result -> { + + assertThat(result.content()).hasSize(1); + + TextContent content = (TextContent) result.content().get(0); + + assertThat(content).isNotNull(); + assertThat(content.text()).isNotNull(); + assertThat(content.text()).contains("7"); + }); + }); + } + + @Test + void testPingWithoutInitialization() { + verifyCallTimesOut(client -> client.ping(), "pinging the server"); + } + + @Test + void testPing() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException(); + }); + } + + @Test + void testCallToolWithoutInitialization() { + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); + verifyCallTimesOut(client -> client.callTool(callToolRequest), "calling tools"); + } + + @Test + void testCallTool() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); + + CallToolResult callToolResult = mcpSyncClient.callTool(callToolRequest); + + assertThat(callToolResult).isNotNull().satisfies(result -> { + assertThat(result.content()).isNotNull(); + assertThat(result.isError()).isNull(); + }); + }); + } + + @Test + void testCallToolWithInvalidTool() { + withClient(createMcpTransport(), mcpSyncClient -> { + CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE)); + + assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class); + }); + } + + @Test + void testRootsListChangedWithoutInitialization() { + verifyNotificationTimesOut(client -> client.rootsListChangedNotification(), + "sending roots list changed notification"); + } + + @Test + void testRootsListChanged() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException(); + }); + } + + @Test + void testListResourcesWithoutInitialization() { + verifyCallTimesOut(client -> client.listResources(null), "listing resources"); + } + + @Test + void testListResources() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourcesResult resources = mcpSyncClient.listResources(null); + + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); + }); + } + + @Test + void testClientSessionState() { + withClient(createMcpTransport(), mcpSyncClient -> { + assertThat(mcpSyncClient).isNotNull(); + }); + } + + @Test + void testInitializeWithRootsListProviders() { + withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), + mcpSyncClient -> { + + assertThatCode(() -> { + mcpSyncClient.initialize(); + mcpSyncClient.close(); + }).doesNotThrowAnyException(); + }); + } + + @Test + void testAddRoot() { + withClient(createMcpTransport(), mcpSyncClient -> { + Root newRoot = new Root("file:///new/test/path", "new-test-root"); + assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException(); + }); + } + + @Test + void testAddRootWithNullValue() { + withClient(createMcpTransport(), mcpSyncClient -> { + assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null"); + }); + } + + @Test + void testRemoveRoot() { + withClient(createMcpTransport(), mcpSyncClient -> { + Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + assertThatCode(() -> { + mcpSyncClient.addRoot(root); + mcpSyncClient.removeRoot(root.uri()); + }).doesNotThrowAnyException(); + }); + } + + @Test + void testRemoveNonExistentRoot() { + withClient(createMcpTransport(), mcpSyncClient -> { + assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri")) + .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); + }); + } + + @Test + void testReadResourceWithoutInitialization() { + Resource resource = new Resource("test://uri", "Test Resource", null, null, null); + verifyCallTimesOut(client -> client.readResource(resource), "reading resources"); + } + + @Test + void testReadResource() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourcesResult resources = mcpSyncClient.listResources(null); + + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); + ReadResourceResult result = mcpSyncClient.readResource(firstResource); + + assertThat(result).isNotNull(); + assertThat(result.contents()).isNotNull(); + } + }); + } + + @Test + void testListResourceTemplatesWithoutInitialization() { + verifyCallTimesOut(client -> client.listResourceTemplates(null), "listing resource templates"); + } + + @Test + void testListResourceTemplates() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(null); + + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }); + } + + // @Test + void testResourceSubscription() { + withClient(createMcpTransport(), mcpSyncClient -> { + ListResourcesResult resources = mcpSyncClient.listResources(null); + + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); + + // Test subscribe + assertThatCode(() -> mcpSyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) + .doesNotThrowAnyException(); + + // Test unsubscribe + assertThatCode(() -> mcpSyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) + .doesNotThrowAnyException(); + } + }); + } + + @Test + void testNotificationHandlers() { + AtomicBoolean toolsNotificationReceived = new AtomicBoolean(false); + AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); + AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); + + withClient(createMcpTransport(), + builder -> builder.toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) + .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) + .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)), + client -> { + + assertThatCode(() -> { + client.initialize(); + client.close(); + }).doesNotThrowAnyException(); + }); + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + + @Test + void testLoggingLevelsWithoutInitialization() { + verifyNotificationTimesOut(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), + "setting logging level"); + } + + @Test + void testLoggingLevels() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + // Test all logging levels + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + assertThatCode(() -> mcpSyncClient.setLoggingLevel(level)).doesNotThrowAnyException(); + } + }); + } + + @Test + void testLoggingConsumer() { + AtomicBoolean logReceived = new AtomicBoolean(false); + withClient(createMcpTransport(), builder -> builder.requestTimeout(getRequestTimeout()) + .loggingConsumer(notification -> logReceived.set(true)), client -> { + assertThatCode(() -> { + client.initialize(); + client.close(); + }).doesNotThrowAnyException(); + }); + } + + @Test + void testLoggingWithNullNotification() { + withClient(createMcpTransport(), mcpSyncClient -> assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null)) + .hasMessageContaining("Logging level must not be null")); + } + +} diff --git a/mcp/src/test/java/org/springframework/ai/mcp/client/ServletSseMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java similarity index 55% rename from mcp/src/test/java/org/springframework/ai/mcp/client/ServletSseMcpAsyncClientTests.java rename to mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java index 86e6976fe..15749d4ff 100644 --- a/mcp/src/test/java/org/springframework/ai/mcp/client/ServletSseMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java @@ -1,35 +1,22 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.client; +package io.modelcontextprotocol.client; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; -import org.springframework.ai.mcp.client.transport.HttpClientSseClientTransport; -import org.springframework.ai.mcp.spec.ClientMcpTransport; - /** * Tests for the {@link McpSyncClient} with {@link HttpClientSseClientTransport}. * * @author Christian Tzolov */ @Timeout(15) // Giving extra time beyond the client timeout -class ServletSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { +class HttpSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { String host = "http://localhost:3004"; @@ -41,7 +28,7 @@ class ServletSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected ClientMcpTransport createMcpTransport() { + protected McpClientTransport createMcpTransport() { return new HttpClientSseClientTransport(host); } diff --git a/mcp/src/test/java/org/springframework/ai/mcp/client/ServletSseMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java similarity index 55% rename from mcp/src/test/java/org/springframework/ai/mcp/client/ServletSseMcpSyncClientTests.java rename to mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java index f33bc976c..067f92957 100644 --- a/mcp/src/test/java/org/springframework/ai/mcp/client/ServletSseMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java @@ -1,35 +1,22 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.client; +package io.modelcontextprotocol.client; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; -import org.springframework.ai.mcp.client.transport.HttpClientSseClientTransport; -import org.springframework.ai.mcp.spec.ClientMcpTransport; - /** * Tests for the {@link McpSyncClient} with {@link HttpClientSseClientTransport}. * * @author Christian Tzolov */ @Timeout(15) // Giving extra time beyond the client timeout -class ServletSseMcpSyncClientTests extends AbstractMcpSyncClientTests { +class HttpSseMcpSyncClientTests extends AbstractMcpSyncClientTests { String host = "http://localhost:3003"; @@ -41,7 +28,7 @@ class ServletSseMcpSyncClientTests extends AbstractMcpSyncClientTests { .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected ClientMcpTransport createMcpTransport() { + protected McpClientTransport createMcpTransport() { return new HttpClientSseClientTransport(host); } diff --git a/mcp/src/test/java/org/springframework/ai/mcp/client/McpAsyncClientResponseHandlerTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java similarity index 68% rename from mcp/src/test/java/org/springframework/ai/mcp/client/McpAsyncClientResponseHandlerTests.java rename to mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java index c5b0ec8e7..b1e82b748 100644 --- a/mcp/src/test/java/org/springframework/ai/mcp/client/McpAsyncClientResponseHandlerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java @@ -1,22 +1,9 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.client; +package io.modelcontextprotocol.client; -import java.time.Duration; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -25,24 +12,86 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.MockMcpTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.Root; import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; -import org.springframework.ai.mcp.MockMcpTransport; -import org.springframework.ai.mcp.spec.McpError; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.ClientCapabilities; -import org.springframework.ai.mcp.spec.McpSchema.Root; - +import static io.modelcontextprotocol.spec.McpSchema.METHOD_INITIALIZE; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.awaitility.Awaitility.await; class McpAsyncClientResponseHandlerTests { + private static final McpSchema.Implementation SERVER_INFO = new McpSchema.Implementation("test-server", "1.0.0"); + + private static final McpSchema.ServerCapabilities SERVER_CAPABILITIES = McpSchema.ServerCapabilities.builder() + .tools(true) + .resources(true, true) // Enable both resources and resource templates + .build(); + + private static MockMcpTransport initializationEnabledTransport() { + return initializationEnabledTransport(SERVER_CAPABILITIES, SERVER_INFO); + } + + private static MockMcpTransport initializationEnabledTransport(McpSchema.ServerCapabilities mockServerCapabilities, + McpSchema.Implementation mockServerInfo) { + McpSchema.InitializeResult mockInitResult = new McpSchema.InitializeResult(McpSchema.LATEST_PROTOCOL_VERSION, + mockServerCapabilities, mockServerInfo, "Test instructions"); + + return new MockMcpTransport((t, message) -> { + if (message instanceof McpSchema.JSONRPCRequest r && METHOD_INITIALIZE.equals(r.method())) { + McpSchema.JSONRPCResponse initResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, + r.id(), mockInitResult, null); + t.simulateIncomingMessage(initResponse); + } + }); + } + + @Test + void testSuccessfulInitialization() { + McpSchema.Implementation serverInfo = new McpSchema.Implementation("mcp-test-server", "0.0.1"); + McpSchema.ServerCapabilities serverCapabilities = McpSchema.ServerCapabilities.builder() + .tools(false) + .resources(true, true) // Enable both resources and resource templates + .build(); + MockMcpTransport transport = initializationEnabledTransport(serverCapabilities, serverInfo); + McpAsyncClient asyncMcpClient = McpClient.async(transport).build(); + + // Verify client is not initialized initially + assertThat(asyncMcpClient.isInitialized()).isFalse(); + + // Start initialization with reactive handling + InitializeResult result = asyncMcpClient.initialize().block(); + + // Verify initialized notification was sent + McpSchema.JSONRPCMessage notificationMessage = transport.getLastSentMessage(); + assertThat(notificationMessage).isInstanceOf(McpSchema.JSONRPCNotification.class); + McpSchema.JSONRPCNotification notification = (McpSchema.JSONRPCNotification) notificationMessage; + assertThat(notification.method()).isEqualTo(McpSchema.METHOD_NOTIFICATION_INITIALIZED); + + // Verify initialization result + assertThat(result).isNotNull(); + assertThat(result.protocolVersion()).isEqualTo(McpSchema.LATEST_PROTOCOL_VERSION); + assertThat(result.capabilities()).isEqualTo(serverCapabilities); + assertThat(result.serverInfo()).isEqualTo(serverInfo); + assertThat(result.instructions()).isEqualTo("Test instructions"); + + // Verify client state after initialization + assertThat(asyncMcpClient.isInitialized()).isTrue(); + assertThat(asyncMcpClient.getServerCapabilities()).isEqualTo(serverCapabilities); + assertThat(asyncMcpClient.getServerInfo()).isEqualTo(serverInfo); + + asyncMcpClient.closeGracefully(); + } + @Test void testToolsChangeNotificationHandling() throws JsonProcessingException { - MockMcpTransport transport = new MockMcpTransport(); + MockMcpTransport transport = initializationEnabledTransport(); // Create a list to store received tools for verification List receivedTools = new ArrayList<>(); @@ -54,6 +103,8 @@ void testToolsChangeNotificationHandling() throws JsonProcessingException { // Create client with tools change consumer McpAsyncClient asyncMcpClient = McpClient.async(transport).toolsChangeConsumer(toolsChangeConsumer).build(); + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + // Create a mock tools list that the server will return Map inputSchema = Map.of("type", "object", "properties", Map.of(), "required", List.of()); McpSchema.Tool mockTool = new McpSchema.Tool("test-tool", "Test Tool Description", @@ -74,23 +125,23 @@ void testToolsChangeNotificationHandling() throws JsonProcessingException { transport.simulateIncomingMessage(toolsListResponse); // Verify the consumer received the expected tools - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(receivedTools).hasSize(1); - assertThat(receivedTools.get(0).name()).isEqualTo("test-tool"); - assertThat(receivedTools.get(0).description()).isEqualTo("Test Tool Description"); - }); + assertThat(receivedTools).hasSize(1); + assertThat(receivedTools.get(0).name()).isEqualTo("test-tool"); + assertThat(receivedTools.get(0).description()).isEqualTo("Test Tool Description"); asyncMcpClient.closeGracefully(); } @Test void testRootsListRequestHandling() { - MockMcpTransport transport = new MockMcpTransport(); + MockMcpTransport transport = initializationEnabledTransport(); McpAsyncClient asyncMcpClient = McpClient.async(transport) .roots(new Root("file:///test/path", "test-root")) .build(); + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + // Simulate incoming request McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_ROOTS_LIST, "test-id", null); @@ -111,7 +162,7 @@ void testRootsListRequestHandling() { @Test void testResourcesChangeNotificationHandling() { - MockMcpTransport transport = new MockMcpTransport(); + MockMcpTransport transport = initializationEnabledTransport(); // Create a list to store received resources for verification List receivedResources = new ArrayList<>(); @@ -125,6 +176,8 @@ void testResourcesChangeNotificationHandling() { .resourcesChangeConsumer(resourcesChangeConsumer) .build(); + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + // Create a mock resources list that the server will return McpSchema.Resource mockResource = new McpSchema.Resource("test://resource", "Test Resource", "A test resource", "text/plain", null); @@ -145,19 +198,17 @@ void testResourcesChangeNotificationHandling() { transport.simulateIncomingMessage(resourcesListResponse); // Verify the consumer received the expected resources - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(receivedResources).hasSize(1); - assertThat(receivedResources.get(0).uri()).isEqualTo("test://resource"); - assertThat(receivedResources.get(0).name()).isEqualTo("Test Resource"); - assertThat(receivedResources.get(0).description()).isEqualTo("A test resource"); - }); + assertThat(receivedResources).hasSize(1); + assertThat(receivedResources.get(0).uri()).isEqualTo("test://resource"); + assertThat(receivedResources.get(0).name()).isEqualTo("Test Resource"); + assertThat(receivedResources.get(0).description()).isEqualTo("A test resource"); asyncMcpClient.closeGracefully(); } @Test void testPromptsChangeNotificationHandling() { - MockMcpTransport transport = new MockMcpTransport(); + MockMcpTransport transport = initializationEnabledTransport(); // Create a list to store received prompts for verification List receivedPrompts = new ArrayList<>(); @@ -169,6 +220,8 @@ void testPromptsChangeNotificationHandling() { // Create client with prompts change consumer McpAsyncClient asyncMcpClient = McpClient.async(transport).promptsChangeConsumer(promptsChangeConsumer).build(); + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + // Create a mock prompts list that the server will return McpSchema.Prompt mockPrompt = new McpSchema.Prompt("test-prompt", "Test Prompt Description", List.of(new McpSchema.PromptArgument("arg1", "Test argument", true))); @@ -188,20 +241,18 @@ void testPromptsChangeNotificationHandling() { transport.simulateIncomingMessage(promptsListResponse); // Verify the consumer received the expected prompts - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(receivedPrompts).hasSize(1); - assertThat(receivedPrompts.get(0).name()).isEqualTo("test-prompt"); - assertThat(receivedPrompts.get(0).description()).isEqualTo("Test Prompt Description"); - assertThat(receivedPrompts.get(0).arguments()).hasSize(1); - assertThat(receivedPrompts.get(0).arguments().get(0).name()).isEqualTo("arg1"); - }); + assertThat(receivedPrompts).hasSize(1); + assertThat(receivedPrompts.get(0).name()).isEqualTo("test-prompt"); + assertThat(receivedPrompts.get(0).description()).isEqualTo("Test Prompt Description"); + assertThat(receivedPrompts.get(0).arguments()).hasSize(1); + assertThat(receivedPrompts.get(0).arguments().get(0).name()).isEqualTo("arg1"); asyncMcpClient.closeGracefully(); } @Test void testSamplingCreateMessageRequestHandling() { - MockMcpTransport transport = new MockMcpTransport(); + MockMcpTransport transport = initializationEnabledTransport(); // Create a test sampling handler that echoes back the input Function> samplingHandler = request -> { @@ -216,6 +267,8 @@ void testSamplingCreateMessageRequestHandling() { .sampling(samplingHandler) .build(); + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + // Create a mock create message request var messageRequest = new McpSchema.CreateMessageRequest( List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))), @@ -253,13 +306,15 @@ void testSamplingCreateMessageRequestHandling() { @Test void testSamplingCreateMessageRequestHandlingWithoutCapability() { - MockMcpTransport transport = new MockMcpTransport(); + MockMcpTransport transport = initializationEnabledTransport(); // Create client without sampling capability McpAsyncClient asyncMcpClient = McpClient.async(transport) .capabilities(ClientCapabilities.builder().build()) // No sampling capability .build(); + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + // Create a mock create message request var messageRequest = new McpSchema.CreateMessageRequest( List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))), diff --git a/mcp/src/test/java/org/springframework/ai/mcp/client/McpClientProtocolVersionTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java similarity index 86% rename from mcp/src/test/java/org/springframework/ai/mcp/client/McpClientProtocolVersionTests.java rename to mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java index 7a0be90ee..58e486e19 100644 --- a/mcp/src/test/java/org/springframework/ai/mcp/client/McpClientProtocolVersionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java @@ -1,33 +1,20 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.client; +package io.modelcontextprotocol.client; import java.time.Duration; import java.util.List; +import io.modelcontextprotocol.MockMcpTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; -import org.springframework.ai.mcp.MockMcpTransport; -import org.springframework.ai.mcp.spec.McpError; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.InitializeResult; - import static org.assertj.core.api.Assertions.assertThat; /** diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java new file mode 100644 index 000000000..95230942c --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java @@ -0,0 +1,35 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.time.Duration; + +import io.modelcontextprotocol.client.transport.ServerParameters; +import io.modelcontextprotocol.client.transport.StdioClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.Timeout; + +/** + * Tests for the {@link McpAsyncClient} with {@link StdioClientTransport}. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +@Timeout(15) // Giving extra time beyond the client timeout +class StdioMcpAsyncClientTests extends AbstractMcpAsyncClientTests { + + @Override + protected McpClientTransport createMcpTransport() { + ServerParameters stdioParams = ServerParameters.builder("npx") + .args("-y", "@modelcontextprotocol/server-everything", "dir") + .build(); + return new StdioClientTransport(stdioParams); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(6); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java new file mode 100644 index 000000000..925852b5b --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java @@ -0,0 +1,67 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.time.Duration; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import io.modelcontextprotocol.client.transport.ServerParameters; +import io.modelcontextprotocol.client.transport.StdioClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for the {@link McpSyncClient} with {@link StdioClientTransport}. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +@Timeout(15) // Giving extra time beyond the client timeout +class StdioMcpSyncClientTests extends AbstractMcpSyncClientTests { + + @Override + protected McpClientTransport createMcpTransport() { + ServerParameters stdioParams = ServerParameters.builder("npx") + .args("-y", "@modelcontextprotocol/server-everything", "dir") + .build(); + + return new StdioClientTransport(stdioParams); + } + + @Test + void customErrorHandlerShouldReceiveErrors() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + AtomicReference receivedError = new AtomicReference<>(); + + McpClientTransport transport = createMcpTransport(); + StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); + + ((StdioClientTransport) transport).setStdErrorHandler(error -> { + receivedError.set(error); + latch.countDown(); + }); + + String errorMessage = "Test error"; + ((StdioClientTransport) transport).getErrorSink().emitNext(errorMessage, Sinks.EmitFailureHandler.FAIL_FAST); + + assertThat(latch.await(5, TimeUnit.SECONDS)).isTrue(); + + assertThat(receivedError.get()).isNotNull().isEqualTo(errorMessage); + + StepVerifier.create(transport.closeGracefully()).expectComplete().verify(Duration.ofSeconds(5)); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(6); + } + +} diff --git a/mcp/src/test/java/org/springframework/ai/mcp/client/transport/HttpClientSseClientTransportTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java similarity index 91% rename from mcp/src/test/java/org/springframework/ai/mcp/client/transport/HttpClientSseClientTransportTests.java rename to mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index f50aeb354..294056fbe 100644 --- a/mcp/src/test/java/org/springframework/ai/mcp/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -1,26 +1,16 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.client.transport; +package io.modelcontextprotocol.client.transport; import java.time.Duration; import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -31,8 +21,6 @@ import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.JSONRPCRequest; import org.springframework.http.codec.ServerSentEvent; import static org.assertj.core.api.Assertions.assertThat; diff --git a/mcp/src/test/java/org/springframework/ai/mcp/server/AbstractMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java similarity index 91% rename from mcp/src/test/java/org/springframework/ai/mcp/server/AbstractMcpAsyncServerTests.java rename to mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java index 869d34a45..b9a19de6c 100644 --- a/mcp/src/test/java/org/springframework/ai/mcp/server/AbstractMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java @@ -1,43 +1,30 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.server; +package io.modelcontextprotocol.server; import java.time.Duration; import java.util.List; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; -import org.springframework.ai.mcp.spec.McpError; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.CallToolResult; -import org.springframework.ai.mcp.spec.McpSchema.GetPromptResult; -import org.springframework.ai.mcp.spec.McpSchema.Prompt; -import org.springframework.ai.mcp.spec.McpSchema.PromptMessage; -import org.springframework.ai.mcp.spec.McpSchema.ReadResourceResult; -import org.springframework.ai.mcp.spec.McpSchema.Resource; -import org.springframework.ai.mcp.spec.McpSchema.ServerCapabilities; -import org.springframework.ai.mcp.spec.McpSchema.Tool; -import org.springframework.ai.mcp.spec.McpTransport; -import org.springframework.ai.mcp.spec.ServerMcpTransport; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -49,7 +36,8 @@ * @author Christian Tzolov */ // KEEP IN SYNC with the class in mcp-test module -public abstract class AbstractMcpAsyncServerTests { +@Deprecated +public abstract class AbstractMcpAsyncServerDeprecatedTests { private static final String TEST_TOOL_NAME = "test-tool"; @@ -80,7 +68,8 @@ void tearDown() { @Test void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.async(null)).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> McpServer.async((ServerMcpTransport) null)) + .isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); assertThatThrownBy(() -> McpServer.async(createMcpTransport()).serverInfo((McpSchema.Implementation) null)) diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java new file mode 100644 index 000000000..4b4fc434f --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -0,0 +1,467 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.List; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Test suite for the {@link McpAsyncServer} that can be used with different + * {@link McpTransportProvider} implementations. + * + * @author Christian Tzolov + */ +public abstract class AbstractMcpAsyncServerTests { + + private static final String TEST_TOOL_NAME = "test-tool"; + + private static final String TEST_RESOURCE_URI = "test://resource"; + + private static final String TEST_PROMPT_NAME = "test-prompt"; + + abstract protected McpServerTransportProvider createMcpTransportProvider(); + + protected void onStart() { + } + + protected void onClose() { + } + + @BeforeEach + void setUp() { + } + + @AfterEach + void tearDown() { + onClose(); + } + + // --------------------------------------- + // Server Lifecycle Tests + // --------------------------------------- + + @Test + void testConstructorWithInvalidArguments() { + assertThatThrownBy(() -> McpServer.async((McpServerTransportProvider) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport provider must not be null"); + + assertThatThrownBy( + () -> McpServer.async(createMcpTransportProvider()).serverInfo((McpSchema.Implementation) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Server info must not be null"); + } + + @Test + void testGracefulShutdown() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(mcpAsyncServer.closeGracefully()).verifyComplete(); + } + + @Test + void testImmediateClose() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @Test + void testAddTool() { + Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, + (excnage, args) -> Mono.just(new CallToolResult(List.of(), false))))) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddDuplicateTool() { + Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(duplicateTool, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .build(); + + StepVerifier + .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(duplicateTool, + (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); + }); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveTool() { + Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .build(); + + StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentTool() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.removeTool("nonexistent-tool")).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class).hasMessage("Tool with name 'nonexistent-tool' not found"); + }); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testNotifyToolsListChanged() { + Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .build(); + + StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resources Tests + // --------------------------------------- + + @Test + void testNotifyResourcesListChanged() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(mcpAsyncServer.notifyResourcesListChanged()).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResource() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", + null); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(mcpAsyncServer.addResource(specification)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithNullSpecification() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceSpecification) null)) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class).hasMessage("Resource must not be null"); + }); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); + + Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", + null); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(serverWithoutResources.addResource(specification)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with resource capabilities"); + }); + } + + @Test + void testRemoveResourceWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); + + StepVerifier.create(serverWithoutResources.removeResource(TEST_RESOURCE_URI)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with resource capabilities"); + }); + } + + // --------------------------------------- + // Prompts Tests + // --------------------------------------- + + @Test + void testNotifyPromptsListChanged() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(mcpAsyncServer.notifyPromptsListChanged()).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddPromptWithNullSpecification() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(false).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptSpecification) null)) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt specification must not be null"); + }); + } + + @Test + void testAddPromptWithoutCapability() { + // Create a server without prompt capabilities + McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); + + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); + McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( + prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); + + StepVerifier.create(serverWithoutPrompts.addPrompt(specification)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with prompt capabilities"); + }); + } + + @Test + void testRemovePromptWithoutCapability() { + // Create a server without prompt capabilities + McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); + + StepVerifier.create(serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with prompt capabilities"); + }); + } + + @Test + void testRemovePrompt() { + String TEST_PROMPT_NAME_TO_REMOVE = "TEST_PROMPT_NAME678"; + + Prompt prompt = new Prompt(TEST_PROMPT_NAME_TO_REMOVE, "Test Prompt", List.of()); + McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( + prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); + + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .prompts(specification) + .build(); + + StepVerifier.create(mcpAsyncServer.removePrompt(TEST_PROMPT_NAME_TO_REMOVE)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentPrompt() { + var mcpAsyncServer2 = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer2.removePrompt("nonexistent-prompt")).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Prompt with name 'nonexistent-prompt' not found"); + }); + + assertThatCode(() -> mcpAsyncServer2.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + + @Test + void testRootsChangeHandlers() { + // Test with single consumer + var rootsReceived = new McpSchema.Root[1]; + var consumerCalled = new boolean[1]; + + var singleConsumerServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { + consumerCalled[0] = true; + if (!roots.isEmpty()) { + rootsReceived[0] = roots.get(0); + } + }))) + .build(); + + assertThat(singleConsumerServer).isNotNull(); + assertThatCode(() -> singleConsumerServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + onClose(); + + // Test with multiple consumers + var consumer1Called = new boolean[1]; + var consumer2Called = new boolean[1]; + var rootsContent = new List[1]; + + var multipleConsumersServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { + consumer1Called[0] = true; + rootsContent[0] = roots; + }), (exchange, roots) -> Mono.fromRunnable(() -> consumer2Called[0] = true))) + .build(); + + assertThat(multipleConsumersServer).isNotNull(); + assertThatCode(() -> multipleConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + onClose(); + + // Test error handling + var errorHandlingServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> { + throw new RuntimeException("Test error"); + })) + .build(); + + assertThat(errorHandlingServer).isNotNull(); + assertThatCode(() -> errorHandlingServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + onClose(); + + // Test without consumers + var noConsumersServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); + + assertThat(noConsumersServer).isNotNull(); + assertThatCode(() -> noConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + + @Test + void testLoggingLevels() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().build()) + .build(); + + // Test all logging levels + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + var notification = McpSchema.LoggingMessageNotification.builder() + .level(level) + .logger("test-logger") + .data("Test message with level " + level) + .build(); + + StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); + } + } + + @Test + void testLoggingWithoutCapability() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().build()) // No logging capability + .build(); + + var notification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Test log message") + .build(); + + StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); + } + + @Test + void testLoggingWithNullNotification() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().build()) + .build(); + + StepVerifier.create(mcpAsyncServer.loggingNotification(null)).verifyError(McpError.class); + } + +} diff --git a/mcp/src/test/java/org/springframework/ai/mcp/server/AbstractMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java similarity index 90% rename from mcp/src/test/java/org/springframework/ai/mcp/server/AbstractMcpSyncServerTests.java rename to mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java index f627bdccd..16bc2d6e4 100644 --- a/mcp/src/test/java/org/springframework/ai/mcp/server/AbstractMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java @@ -1,40 +1,27 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.server; +package io.modelcontextprotocol.server; import java.util.List; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.springframework.ai.mcp.spec.McpError; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.CallToolResult; -import org.springframework.ai.mcp.spec.McpSchema.GetPromptResult; -import org.springframework.ai.mcp.spec.McpSchema.Prompt; -import org.springframework.ai.mcp.spec.McpSchema.PromptMessage; -import org.springframework.ai.mcp.spec.McpSchema.ReadResourceResult; -import org.springframework.ai.mcp.spec.McpSchema.Resource; -import org.springframework.ai.mcp.spec.McpSchema.ServerCapabilities; -import org.springframework.ai.mcp.spec.McpSchema.Tool; -import org.springframework.ai.mcp.spec.McpTransport; -import org.springframework.ai.mcp.spec.ServerMcpTransport; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -46,7 +33,8 @@ * @author Christian Tzolov */ // KEEP IN SYNC with the class in mcp-test module -public abstract class AbstractMcpSyncServerTests { +@Deprecated +public abstract class AbstractMcpSyncServerDeprecatedTests { private static final String TEST_TOOL_NAME = "test-tool"; @@ -78,7 +66,7 @@ void tearDown() { @Test void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.sync(null)).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> McpServer.sync((ServerMcpTransport) null)).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); assertThatThrownBy(() -> McpServer.sync(createMcpTransport()).serverInfo(null)) diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java new file mode 100644 index 000000000..17feb36e5 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -0,0 +1,439 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.util.List; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Test suite for the {@link McpSyncServer} that can be used with different + * {@link McpTransportProvider} implementations. + * + * @author Christian Tzolov + */ +public abstract class AbstractMcpSyncServerTests { + + private static final String TEST_TOOL_NAME = "test-tool"; + + private static final String TEST_RESOURCE_URI = "test://resource"; + + private static final String TEST_PROMPT_NAME = "test-prompt"; + + abstract protected McpServerTransportProvider createMcpTransportProvider(); + + protected void onStart() { + } + + protected void onClose() { + } + + @BeforeEach + void setUp() { + // onStart(); + } + + @AfterEach + void tearDown() { + onClose(); + } + + // --------------------------------------- + // Server Lifecycle Tests + // --------------------------------------- + + @Test + void testConstructorWithInvalidArguments() { + assertThatThrownBy(() -> McpServer.sync((McpServerTransportProvider) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport provider must not be null"); + + assertThatThrownBy(() -> McpServer.sync(createMcpTransportProvider()).serverInfo(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Server info must not be null"); + } + + @Test + void testGracefulShutdown() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testImmediateClose() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.close()).doesNotThrowAnyException(); + } + + @Test + void testGetAsyncServer() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + + assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @Test + void testAddTool() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); + assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(newTool, + (exchange, args) -> new CallToolResult(List.of(), false)))) + .doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddDuplicateTool() { + Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(duplicateTool, (exchange, args) -> new CallToolResult(List.of(), false)) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool, + (exchange, args) -> new CallToolResult(List.of(), false)))) + .isInstanceOf(McpError.class) + .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testRemoveTool() { + Tool tool = new McpSchema.Tool(TEST_TOOL_NAME, "Test tool", emptyJsonSchema); + + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(tool, (exchange, args) -> new CallToolResult(List.of(), false)) + .build(); + + assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentTool() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.removeTool("nonexistent-tool")).isInstanceOf(McpError.class) + .hasMessage("Tool with name 'nonexistent-tool' not found"); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testNotifyToolsListChanged() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.notifyToolsListChanged()).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resources Tests + // --------------------------------------- + + @Test + void testNotifyResourcesListChanged() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.notifyResourcesListChanged()).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddResource() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", + null); + McpServerFeatures.SyncResourceSpecification specificaiton = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); + + assertThatCode(() -> mcpSyncServer.addResource(specificaiton)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithNullSpecifiation() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceSpecification) null)) + .isInstanceOf(McpError.class) + .hasMessage("Resource must not be null"); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithoutCapability() { + var serverWithoutResources = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); + + Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", + null); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); + + assertThatThrownBy(() -> serverWithoutResources.addResource(specification)).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with resource capabilities"); + } + + @Test + void testRemoveResourceWithoutCapability() { + var serverWithoutResources = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); + + assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with resource capabilities"); + } + + // --------------------------------------- + // Prompts Tests + // --------------------------------------- + + @Test + void testNotifyPromptsListChanged() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.notifyPromptsListChanged()).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddPromptWithNullSpecification() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(false).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptSpecification) null)) + .isInstanceOf(McpError.class) + .hasMessage("Prompt specification must not be null"); + } + + @Test + void testAddPromptWithoutCapability() { + var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); + + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); + McpServerFeatures.SyncPromptSpecification specificaiton = new McpServerFeatures.SyncPromptSpecification(prompt, + (exchange, req) -> new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); + + assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specificaiton)).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with prompt capabilities"); + } + + @Test + void testRemovePromptWithoutCapability() { + var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); + + assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with prompt capabilities"); + } + + @Test + void testRemovePrompt() { + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); + McpServerFeatures.SyncPromptSpecification specificaiton = new McpServerFeatures.SyncPromptSpecification(prompt, + (exchange, req) -> new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); + + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .prompts(specificaiton) + .build(); + + assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentPrompt() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.removePrompt("nonexistent-prompt")).isInstanceOf(McpError.class) + .hasMessage("Prompt with name 'nonexistent-prompt' not found"); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + + @Test + void testRootsChangeHandlers() { + // Test with single consumer + var rootsReceived = new McpSchema.Root[1]; + var consumerCalled = new boolean[1]; + + var singleConsumerServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchage, roots) -> { + consumerCalled[0] = true; + if (!roots.isEmpty()) { + rootsReceived[0] = roots.get(0); + } + })) + .build(); + + assertThat(singleConsumerServer).isNotNull(); + assertThatCode(() -> singleConsumerServer.closeGracefully()).doesNotThrowAnyException(); + onClose(); + + // Test with multiple consumers + var consumer1Called = new boolean[1]; + var consumer2Called = new boolean[1]; + var rootsContent = new List[1]; + + var multipleConsumersServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> { + consumer1Called[0] = true; + rootsContent[0] = roots; + }, (exchange, roots) -> consumer2Called[0] = true)) + .build(); + + assertThat(multipleConsumersServer).isNotNull(); + assertThatCode(() -> multipleConsumersServer.closeGracefully()).doesNotThrowAnyException(); + onClose(); + + // Test error handling + var errorHandlingServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> { + throw new RuntimeException("Test error"); + })) + .build(); + + assertThat(errorHandlingServer).isNotNull(); + assertThatCode(() -> errorHandlingServer.closeGracefully()).doesNotThrowAnyException(); + onClose(); + + // Test without consumers + var noConsumersServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + + assertThat(noConsumersServer).isNotNull(); + assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + + @Test + void testLoggingLevels() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().build()) + .build(); + + // Test all logging levels + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + var notification = McpSchema.LoggingMessageNotification.builder() + .level(level) + .logger("test-logger") + .data("Test message with level " + level) + .build(); + + assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); + } + } + + @Test + void testLoggingWithoutCapability() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().build()) // No logging capability + .build(); + + var notification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Test log message") + .build(); + + assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); + } + + @Test + void testLoggingWithNullNotification() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.loggingNotification(null)).isInstanceOf(McpError.class); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/BaseMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/BaseMcpAsyncServerTests.java new file mode 100644 index 000000000..208bcb71b --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/BaseMcpAsyncServerTests.java @@ -0,0 +1,5 @@ +package io.modelcontextprotocol.server; + +public abstract class BaseMcpAsyncServerTests { + +} diff --git a/mcp/src/test/java/org/springframework/ai/mcp/server/McpServerProtocolVersionTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java similarity index 87% rename from mcp/src/test/java/org/springframework/ai/mcp/server/McpServerProtocolVersionTests.java rename to mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java index f557921d2..97358723f 100644 --- a/mcp/src/test/java/org/springframework/ai/mcp/server/McpServerProtocolVersionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java @@ -1,29 +1,16 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.server; +package io.modelcontextprotocol.server; import java.util.List; import java.util.UUID; +import io.modelcontextprotocol.MockMcpTransport; +import io.modelcontextprotocol.spec.McpSchema; import org.junit.jupiter.api.Test; -import org.springframework.ai.mcp.MockMcpTransport; -import org.springframework.ai.mcp.spec.McpSchema; - import static org.assertj.core.api.Assertions.assertThat; /** diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerDeprecatedTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerDeprecatedTests.java new file mode 100644 index 000000000..2c80d45c6 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerDeprecatedTests.java @@ -0,0 +1,26 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.Timeout; + +/** + * Tests for {@link McpAsyncServer} using {@link HttpServletSseServerTransport}. + * + * @author Christian Tzolov + */ +@Deprecated +@Timeout(15) // Giving extra time beyond the client timeout +class ServletSseMcpAsyncServerDeprecatedTests extends AbstractMcpAsyncServerDeprecatedTests { + + @Override + protected ServerMcpTransport createMcpTransport() { + return new HttpServletSseServerTransport(new ObjectMapper(), "/mcp/message"); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java new file mode 100644 index 000000000..9de186b4b --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java @@ -0,0 +1,25 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.Timeout; + +/** + * Tests for {@link McpAsyncServer} using {@link HttpServletSseServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class ServletSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { + + @Override + protected McpServerTransportProvider createMcpTransportProvider() { + return new HttpServletSseServerTransportProvider(new ObjectMapper(), "/mcp/message"); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerDeprecatedTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerDeprecatedTests.java new file mode 100644 index 000000000..8cdd08c5d --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerDeprecatedTests.java @@ -0,0 +1,26 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.Timeout; + +/** + * Tests for {@link McpSyncServer} using {@link HttpServletSseServerTransport}. + * + * @author Christian Tzolov + */ +@Deprecated +@Timeout(15) // Giving extra time beyond the client timeout +class ServletSseMcpSyncServerDeprecatedTests extends AbstractMcpSyncServerDeprecatedTests { + + @Override + protected ServerMcpTransport createMcpTransport() { + return new HttpServletSseServerTransport(new ObjectMapper(), "/mcp/message"); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java new file mode 100644 index 000000000..60dc53a4a --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java @@ -0,0 +1,25 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.Timeout; + +/** + * Tests for {@link McpSyncServer} using {@link HttpServletSseServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class ServletSseMcpSyncServerTests extends AbstractMcpSyncServerTests { + + @Override + protected McpServerTransportProvider createMcpTransportProvider() { + return new HttpServletSseServerTransportProvider(new ObjectMapper(), "/mcp/message"); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerDeprecatedTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerDeprecatedTests.java new file mode 100644 index 000000000..db95db07b --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerDeprecatedTests.java @@ -0,0 +1,25 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.server.transport.StdioServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.Timeout; + +/** + * Tests for {@link McpAsyncServer} using {@link StdioServerTransport}. + * + * @author Christian Tzolov + */ +@Deprecated +@Timeout(15) // Giving extra time beyond the client timeout +class StdioMcpAsyncServerDeprecatedTests extends AbstractMcpAsyncServerDeprecatedTests { + + @Override + protected ServerMcpTransport createMcpTransport() { + return new StdioServerTransport(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java new file mode 100644 index 000000000..27ff53c93 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java @@ -0,0 +1,25 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.server.transport.StdioServerTransport; +import io.modelcontextprotocol.server.transport.StdioServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.Timeout; + +/** + * Tests for {@link McpAsyncServer} using {@link StdioServerTransport}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class StdioMcpAsyncServerTests extends AbstractMcpAsyncServerTests { + + @Override + protected McpServerTransportProvider createMcpTransportProvider() { + return new StdioServerTransportProvider(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerDeprecatedTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerDeprecatedTests.java new file mode 100644 index 000000000..149f72819 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerDeprecatedTests.java @@ -0,0 +1,25 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.server.transport.StdioServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.Timeout; + +/** + * Tests for {@link McpSyncServer} using {@link StdioServerTransport}. + * + * @author Christian Tzolov + */ +@Deprecated +@Timeout(15) // Giving extra time beyond the client timeout +class StdioMcpSyncServerDeprecatedTests extends AbstractMcpSyncServerDeprecatedTests { + + @Override + protected ServerMcpTransport createMcpTransport() { + return new StdioServerTransport(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java new file mode 100644 index 000000000..a71c38493 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java @@ -0,0 +1,24 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.server.transport.StdioServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.Timeout; + +/** + * Tests for {@link McpSyncServer} using {@link StdioServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class StdioMcpSyncServerTests extends AbstractMcpSyncServerTests { + + @Override + protected McpServerTransportProvider createMcpTransportProvider() { + return new StdioServerTransportProvider(); + } + +} diff --git a/mcp/src/test/java/org/springframework/ai/mcp/server/transport/HttpServletSseServerTransportIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportIntegrationTests.java similarity index 83% rename from mcp/src/test/java/org/springframework/ai/mcp/server/transport/HttpServletSseServerTransportIntegrationTests.java rename to mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportIntegrationTests.java index bc6bc06d3..4a292da31 100644 --- a/mcp/src/test/java/org/springframework/ai/mcp/server/transport/HttpServletSseServerTransportIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportIntegrationTests.java @@ -1,19 +1,7 @@ /* * Copyright 2024 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.server.transport; +package io.modelcontextprotocol.server.transport; import java.time.Duration; import java.util.List; @@ -22,6 +10,21 @@ import java.util.function.Function; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; import org.apache.catalina.Context; import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; @@ -29,25 +32,8 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import reactor.test.StepVerifier; -import org.springframework.ai.mcp.client.McpClient; -import org.springframework.ai.mcp.client.transport.HttpClientSseClientTransport; -import org.springframework.ai.mcp.server.McpServer; -import org.springframework.ai.mcp.server.McpServerFeatures; -import org.springframework.ai.mcp.spec.McpError; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.CallToolResult; -import org.springframework.ai.mcp.spec.McpSchema.ClientCapabilities; -import org.springframework.ai.mcp.spec.McpSchema.CreateMessageRequest; -import org.springframework.ai.mcp.spec.McpSchema.CreateMessageResult; -import org.springframework.ai.mcp.spec.McpSchema.InitializeResult; -import org.springframework.ai.mcp.spec.McpSchema.Role; -import org.springframework.ai.mcp.spec.McpSchema.Root; -import org.springframework.ai.mcp.spec.McpSchema.ServerCapabilities; -import org.springframework.ai.mcp.spec.McpSchema.Tool; import org.springframework.web.client.RestClient; import static org.assertj.core.api.Assertions.assertThat; @@ -55,8 +41,6 @@ public class HttpServletSseServerTransportIntegrationTests { - private static final Logger logger = LoggerFactory.getLogger(HttpServletSseServerTransportIntegrationTests.class); - private static final int PORT = 8184; private static final String MESSAGE_ENDPOINT = "/mcp/message"; @@ -239,7 +223,7 @@ void testToolCallSuccess() { new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { String response = RestClient.create() .get() - .uri("https://github.com/spring-projects-experimental/spring-ai-mcp/blob/main/README.md") + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); @@ -274,7 +258,7 @@ void testToolListChangeHandlingSuccess() { new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { String response = RestClient.create() .get() - .uri("https://github.com/spring-projects-experimental/spring-ai-mcp/blob/main/README.md") + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); @@ -290,7 +274,7 @@ void testToolListChangeHandlingSuccess() { var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { String response = RestClient.create() .get() - .uri("https://github.com/spring-projects-experimental/spring-ai-mcp/blob/main/README.md") + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java new file mode 100644 index 000000000..fd8a4e9f9 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -0,0 +1,497 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server.transport; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.web.client.RestClient; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; +import static org.mockito.Mockito.mock; + +public class HttpServletSseServerTransportProviderIntegrationTests { + + private static final int PORT = 8185; + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private HttpServletSseServerTransportProvider mcpServerTransportProvider; + + McpClient.SyncSpec clientBuilder; + + private Tomcat tomcat; + + @BeforeEach + public void before() { + tomcat = new Tomcat(); + tomcat.setPort(PORT); + + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + Context context = tomcat.addContext("", baseDir); + + // Create and configure the transport provider + mcpServerTransportProvider = new HttpServletSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + + // Add transport servlet to Tomcat + org.apache.catalina.Wrapper wrapper = context.createWrapper(); + wrapper.setName("mcpServlet"); + wrapper.setServlet(mcpServerTransportProvider); + wrapper.setLoadOnStartup(1); + wrapper.setAsyncSupported(true); + context.addChild(wrapper); + context.addServletMappingDecoded("/*", "mcpServlet"); + + try { + var connector = tomcat.getConnector(); + connector.setAsyncTimeout(3000); + tomcat.start(); + assertThat(tomcat.getServer().getState() == LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + this.clientBuilder = McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT)); + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + // --------------------------------------- + // Sampling Tests + // --------------------------------------- + @Test + @Disabled + void testCreateMessageWithoutSamplingCapabilities() { + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); + + return Mono.just(mock(CallToolResult.class)); + }); + + McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + } + } + + @Test + void testCreateMessageSuccess() throws InterruptedException { + + // Client + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + @Test + void testRootsSuccess() { + List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Remove a root + mcpClient.removeRoot(roots.get(0).uri()); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); + }); + + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testRootsWithoutCapability() { + + McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.listRoots(); // try to list roots + + return mock(CallToolResult.class); + }); + + var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { + }).tools(tool).build(); + + // Create client without roots capability + // No roots capability + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build(); + + assertThat(mcpClient.initialize()).isNotNull(); + + // Attempt to list roots should fail + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testRootsNotifciationWithEmptyRootsList() { + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(List.of()) // Empty roots list + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testRootsWithMultipleHandlers() { + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef1 = new AtomicReference<>(); + AtomicReference> rootsRef2 = new AtomicReference<>(); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + assertThat(mcpClient.initialize()).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testRootsServerCloseWithActiveSubscription() { + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Close server while subscription is active + mcpServer.close(); + + // Verify client can handle server closure gracefully + mcpClient.close(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @Test + void testToolCallSuccess() { + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + var mcpClient = clientBuilder.build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testToolListChangeHandlingSuccess() { + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + rootsRef.set(toolsUpdate); + }).build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + mcpServer.notifyToolsListChanged(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); + }); + + // Remove a tool + mcpServer.removeTool("tool1"); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + + // Add a new tool + McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), (exchange, request) -> callResponse); + + mcpServer.addTool(tool2); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testInitialize() { + + var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); + + var mcpClient = clientBuilder.build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.close(); + mcpServer.close(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java new file mode 100644 index 000000000..14987b5ac --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java @@ -0,0 +1,227 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.InputStream; +import java.io.PrintStream; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link StdioServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Disabled +class StdioServerTransportProviderTests { + + private final PrintStream originalOut = System.out; + + private final PrintStream originalErr = System.err; + + private ByteArrayOutputStream testErr; + + private PrintStream testOutPrintStream; + + private StdioServerTransportProvider transportProvider; + + private ObjectMapper objectMapper; + + private McpServerSession.Factory sessionFactory; + + private McpServerSession mockSession; + + @BeforeEach + void setUp() { + testErr = new ByteArrayOutputStream(); + + testOutPrintStream = new PrintStream(testErr, true); + System.setOut(testOutPrintStream); + System.setErr(testOutPrintStream); + + objectMapper = new ObjectMapper(); + + // Create mocks for session factory and session + mockSession = mock(McpServerSession.class); + sessionFactory = mock(McpServerSession.Factory.class); + + // Configure mock behavior + when(sessionFactory.create(any(McpServerTransport.class))).thenReturn(mockSession); + when(mockSession.closeGracefully()).thenReturn(Mono.empty()); + when(mockSession.sendNotification(any(), any())).thenReturn(Mono.empty()); + + transportProvider = new StdioServerTransportProvider(objectMapper, System.in, testOutPrintStream); + } + + @AfterEach + void tearDown() { + if (transportProvider != null) { + transportProvider.closeGracefully().block(); + } + if (testOutPrintStream != null) { + testOutPrintStream.close(); + } + System.setOut(originalOut); + System.setErr(originalErr); + } + + @Test + void shouldCreateSessionWhenSessionFactoryIsSet() { + // Set session factory + transportProvider.setSessionFactory(sessionFactory); + + // Verify session was created with a transport + assertThat(testErr.toString()).doesNotContain("Error"); + } + + @Test + void shouldHandleIncomingMessages() throws Exception { + + String jsonMessage = "{\"jsonrpc\":\"2.0\",\"method\":\"test\",\"params\":{},\"id\":1}\n"; + InputStream stream = new ByteArrayInputStream(jsonMessage.getBytes(StandardCharsets.UTF_8)); + + transportProvider = new StdioServerTransportProvider(objectMapper, stream, System.out); + // Set up a real session to capture the message + AtomicReference capturedMessage = new AtomicReference<>(); + CountDownLatch messageLatch = new CountDownLatch(1); + + McpServerSession.Factory realSessionFactory = transport -> { + McpServerSession session = mock(McpServerSession.class); + when(session.handle(any())).thenAnswer(invocation -> { + capturedMessage.set(invocation.getArgument(0)); + messageLatch.countDown(); + return Mono.empty(); + }); + when(session.closeGracefully()).thenReturn(Mono.empty()); + return session; + }; + + // Set session factory + transportProvider.setSessionFactory(realSessionFactory); + + // Wait for the message to be processed using the latch + StepVerifier.create(Mono.fromCallable(() -> messageLatch.await(100, TimeUnit.SECONDS)).flatMap(success -> { + if (!success) { + return Mono.error(new AssertionError("Timeout waiting for message processing")); + } + return Mono.just(capturedMessage.get()); + })).assertNext(message -> { + assertThat(message).isNotNull(); + assertThat(message).isInstanceOf(McpSchema.JSONRPCRequest.class); + McpSchema.JSONRPCRequest request = (McpSchema.JSONRPCRequest) message; + assertThat(request.method()).isEqualTo("test"); + assertThat(request.id()).isEqualTo(1); + }).verifyComplete(); + } + + @Test + void shouldNotifyClients() { + // Set session factory + transportProvider.setSessionFactory(sessionFactory); + + // Send notification + String method = "testNotification"; + Map params = Map.of("key", "value"); + + StepVerifier.create(transportProvider.notifyClients(method, params)).verifyComplete(); + + // Error log should be empty + assertThat(testErr.toString()).doesNotContain("Error"); + } + + @Test + void shouldCloseGracefully() { + // Set session factory + transportProvider.setSessionFactory(sessionFactory); + + // Close gracefully + StepVerifier.create(transportProvider.closeGracefully()).verifyComplete(); + + // Error log should be empty + assertThat(testErr.toString()).doesNotContain("Error"); + } + + @Test + void shouldHandleMultipleCloseGracefullyCalls() { + // Set session factory + transportProvider.setSessionFactory(sessionFactory); + + // Close gracefully multiple times + StepVerifier + .create(transportProvider.closeGracefully() + .then(transportProvider.closeGracefully()) + .then(transportProvider.closeGracefully())) + .verifyComplete(); + + // Error log should be empty + assertThat(testErr.toString()).doesNotContain("Error"); + } + + @Test + void shouldHandleNotificationBeforeSessionFactoryIsSet() { + + transportProvider = new StdioServerTransportProvider(objectMapper); + // Send notification before setting session factory + StepVerifier.create(transportProvider.notifyClients("testNotification", Map.of("key", "value"))) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class); + }); + } + + @Test + void shouldHandleInvalidJsonMessage() throws Exception { + + // Write an invalid JSON message to the input stream + String jsonMessage = "{invalid json}\n"; + InputStream stream = new ByteArrayInputStream(jsonMessage.getBytes(StandardCharsets.UTF_8)); + + transportProvider = new StdioServerTransportProvider(objectMapper, stream, testOutPrintStream); + + // Set up a session factory + transportProvider.setSessionFactory(sessionFactory); + + // Use StepVerifier with a timeout to wait for the error to be processed + StepVerifier + .create(Mono.delay(java.time.Duration.ofMillis(500)).then(Mono.fromCallable(() -> testErr.toString()))) + .assertNext(errorOutput -> assertThat(errorOutput).contains("Error processing inbound message")) + .verifyComplete(); + } + + @Test + void shouldHandleSessionClose() throws Exception { + // Set session factory + transportProvider.setSessionFactory(sessionFactory); + + // Close the transport provider + transportProvider.close(); + + // Verify session was closed + verify(mockSession).closeGracefully(); + } + +} diff --git a/mcp/src/test/java/org/springframework/ai/mcp/server/transport/StdioServerTransportTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportTests.java similarity index 86% rename from mcp/src/test/java/org/springframework/ai/mcp/server/transport/StdioServerTransportTests.java rename to mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportTests.java index c6ce443be..43e5019fc 100644 --- a/mcp/src/test/java/org/springframework/ai/mcp/server/transport/StdioServerTransportTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportTests.java @@ -1,20 +1,8 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.server.transport; +package io.modelcontextprotocol.server.transport; import java.io.ByteArrayOutputStream; import java.io.InputStream; @@ -23,6 +11,8 @@ import java.util.Map; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; @@ -30,9 +20,6 @@ import reactor.core.publisher.Mono; import reactor.test.StepVerifier; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.JSONRPCRequest; - import static org.assertj.core.api.Assertions.assertThat; /** diff --git a/mcp/src/test/java/org/springframework/ai/mcp/spec/DefaultMcpSessionTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java similarity index 82% rename from mcp/src/test/java/org/springframework/ai/mcp/spec/DefaultMcpSessionTests.java rename to mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java index 9af6513e6..79a1d0d92 100644 --- a/mcp/src/test/java/org/springframework/ai/mcp/spec/DefaultMcpSessionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java @@ -1,25 +1,14 @@ /* * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ -package org.springframework.ai.mcp.spec; +package io.modelcontextprotocol.spec; import java.time.Duration; import java.util.Map; import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.MockMcpTransport; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -29,20 +18,18 @@ import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; -import org.springframework.ai.mcp.MockMcpTransport; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; /** - * Test suite for {@link DefaultMcpSession} that verifies its JSON-RPC message handling, + * Test suite for {@link McpClientSession} that verifies its JSON-RPC message handling, * request-response correlation, and notification processing. * * @author Christian Tzolov */ -class DefaultMcpSessionTests { +class McpClientSessionTests { - private static final Logger logger = LoggerFactory.getLogger(DefaultMcpSessionTests.class); + private static final Logger logger = LoggerFactory.getLogger(McpClientSessionTests.class); private static final Duration TIMEOUT = Duration.ofSeconds(5); @@ -52,14 +39,14 @@ class DefaultMcpSessionTests { private static final String ECHO_METHOD = "echo"; - private DefaultMcpSession session; + private McpClientSession session; private MockMcpTransport transport; @BeforeEach void setUp() { transport = new MockMcpTransport(); - session = new DefaultMcpSession(TIMEOUT, transport, Map.of(), + session = new McpClientSession(TIMEOUT, transport, Map.of(), Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: " + params)))); } @@ -72,11 +59,11 @@ void tearDown() { @Test void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> new DefaultMcpSession(null, transport, Map.of(), Map.of())) + assertThatThrownBy(() -> new McpClientSession(null, transport, Map.of(), Map.of())) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("requstTimeout can not be null"); - assertThatThrownBy(() -> new DefaultMcpSession(TIMEOUT, null, Map.of(), Map.of())) + assertThatThrownBy(() -> new McpClientSession(TIMEOUT, null, Map.of(), Map.of())) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("transport can not be null"); } @@ -150,10 +137,10 @@ void testSendNotification() { @Test void testRequestHandling() { String echoMessage = "Hello MCP!"; - Map> requestHandlers = Map.of(ECHO_METHOD, + Map> requestHandlers = Map.of(ECHO_METHOD, params -> Mono.just(params)); transport = new MockMcpTransport(); - session = new DefaultMcpSession(TIMEOUT, transport, requestHandlers, Map.of()); + session = new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of()); // Simulate incoming request McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD, @@ -173,7 +160,7 @@ void testNotificationHandling() { Sinks.One receivedParams = Sinks.one(); transport = new MockMcpTransport(); - session = new DefaultMcpSession(TIMEOUT, transport, Map.of(), + session = new McpClientSession(TIMEOUT, transport, Map.of(), Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> receivedParams.tryEmitValue(params)))); // Simulate incoming notification from the server diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java new file mode 100644 index 000000000..1b8adc33b --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java @@ -0,0 +1,602 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ +package io.modelcontextprotocol.spec; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.exc.InvalidTypeIdException; +import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; +import net.javacrumbs.jsonunit.core.Option; +import org.junit.jupiter.api.Test; + +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * @author Christian Tzolov + */ +public class McpSchemaTests { + + ObjectMapper mapper = new ObjectMapper(); + + // Content Types Tests + + @Test + void testTextContent() throws Exception { + McpSchema.TextContent test = new McpSchema.TextContent("XXX"); + String value = mapper.writeValueAsString(test); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"type":"text","text":"XXX"}""")); + } + + @Test + void testTextContentDeserialization() throws Exception { + McpSchema.TextContent textContent = mapper.readValue(""" + {"type":"text","text":"XXX"}""", McpSchema.TextContent.class); + + assertThat(textContent).isNotNull(); + assertThat(textContent.type()).isEqualTo("text"); + assertThat(textContent.text()).isEqualTo("XXX"); + } + + @Test + void testContentDeserializationWrongType() throws Exception { + + assertThatThrownBy(() -> mapper.readValue(""" + {"type":"WRONG","text":"XXX"}""", McpSchema.TextContent.class)) + .isInstanceOf(InvalidTypeIdException.class) + .hasMessageContaining( + "Could not resolve type id 'WRONG' as a subtype of `io.modelcontextprotocol.spec.McpSchema$TextContent`: known type ids = [image, resource, text]"); + } + + @Test + void testImageContent() throws Exception { + McpSchema.ImageContent test = new McpSchema.ImageContent(null, null, "base64encodeddata", "image/png"); + String value = mapper.writeValueAsString(test); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"type":"image","data":"base64encodeddata","mimeType":"image/png"}""")); + } + + @Test + void testImageContentDeserialization() throws Exception { + McpSchema.ImageContent imageContent = mapper.readValue(""" + {"type":"image","data":"base64encodeddata","mimeType":"image/png"}""", McpSchema.ImageContent.class); + assertThat(imageContent).isNotNull(); + assertThat(imageContent.type()).isEqualTo("image"); + assertThat(imageContent.data()).isEqualTo("base64encodeddata"); + assertThat(imageContent.mimeType()).isEqualTo("image/png"); + } + + @Test + void testEmbeddedResource() throws Exception { + McpSchema.TextResourceContents resourceContents = new McpSchema.TextResourceContents("resource://test", + "text/plain", "Sample resource content"); + + McpSchema.EmbeddedResource test = new McpSchema.EmbeddedResource(null, null, resourceContents); + + String value = mapper.writeValueAsString(test); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"type":"resource","resource":{"uri":"resource://test","mimeType":"text/plain","text":"Sample resource content"}}""")); + } + + @Test + void testEmbeddedResourceDeserialization() throws Exception { + McpSchema.EmbeddedResource embeddedResource = mapper.readValue( + """ + {"type":"resource","resource":{"uri":"resource://test","mimeType":"text/plain","text":"Sample resource content"}}""", + McpSchema.EmbeddedResource.class); + assertThat(embeddedResource).isNotNull(); + assertThat(embeddedResource.type()).isEqualTo("resource"); + assertThat(embeddedResource.resource()).isNotNull(); + assertThat(embeddedResource.resource().uri()).isEqualTo("resource://test"); + assertThat(embeddedResource.resource().mimeType()).isEqualTo("text/plain"); + assertThat(((TextResourceContents) embeddedResource.resource()).text()).isEqualTo("Sample resource content"); + } + + @Test + void testEmbeddedResourceWithBlobContents() throws Exception { + McpSchema.BlobResourceContents resourceContents = new McpSchema.BlobResourceContents("resource://test", + "application/octet-stream", "base64encodedblob"); + + McpSchema.EmbeddedResource test = new McpSchema.EmbeddedResource(null, null, resourceContents); + + String value = mapper.writeValueAsString(test); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"type":"resource","resource":{"uri":"resource://test","mimeType":"application/octet-stream","blob":"base64encodedblob"}}""")); + } + + @Test + void testEmbeddedResourceWithBlobContentsDeserialization() throws Exception { + McpSchema.EmbeddedResource embeddedResource = mapper.readValue( + """ + {"type":"resource","resource":{"uri":"resource://test","mimeType":"application/octet-stream","blob":"base64encodedblob"}}""", + McpSchema.EmbeddedResource.class); + assertThat(embeddedResource).isNotNull(); + assertThat(embeddedResource.type()).isEqualTo("resource"); + assertThat(embeddedResource.resource()).isNotNull(); + assertThat(embeddedResource.resource().uri()).isEqualTo("resource://test"); + assertThat(embeddedResource.resource().mimeType()).isEqualTo("application/octet-stream"); + assertThat(((McpSchema.BlobResourceContents) embeddedResource.resource()).blob()) + .isEqualTo("base64encodedblob"); + } + + // JSON-RPC Message Types Tests + + @Test + void testJSONRPCRequest() throws Exception { + Map params = new HashMap<>(); + params.put("key", "value"); + + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method_name", 1, + params); + + String value = mapper.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"jsonrpc":"2.0","method":"method_name","id":1,"params":{"key":"value"}}""")); + } + + @Test + void testJSONRPCNotification() throws Exception { + Map params = new HashMap<>(); + params.put("key", "value"); + + McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + "notification_method", params); + + String value = mapper.writeValueAsString(notification); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"jsonrpc":"2.0","method":"notification_method","params":{"key":"value"}}""")); + } + + @Test + void testJSONRPCResponse() throws Exception { + Map result = new HashMap<>(); + result.put("result_key", "result_value"); + + McpSchema.JSONRPCResponse response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, 1, result, null); + + String value = mapper.writeValueAsString(response); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"jsonrpc":"2.0","id":1,"result":{"result_key":"result_value"}}""")); + } + + @Test + void testJSONRPCResponseWithError() throws Exception { + McpSchema.JSONRPCResponse.JSONRPCError error = new McpSchema.JSONRPCResponse.JSONRPCError( + McpSchema.ErrorCodes.INVALID_REQUEST, "Invalid request", null); + + McpSchema.JSONRPCResponse response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, 1, null, error); + + String value = mapper.writeValueAsString(response); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"Invalid request"}}""")); + } + + // Initialization Tests + + @Test + void testInitializeRequest() throws Exception { + McpSchema.ClientCapabilities capabilities = McpSchema.ClientCapabilities.builder() + .roots(true) + .sampling() + .build(); + + McpSchema.Implementation clientInfo = new McpSchema.Implementation("test-client", "1.0.0"); + + McpSchema.InitializeRequest request = new McpSchema.InitializeRequest("2024-11-05", capabilities, clientInfo); + + String value = mapper.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"protocolVersion":"2024-11-05","capabilities":{"roots":{"listChanged":true},"sampling":{}},"clientInfo":{"name":"test-client","version":"1.0.0"}}""")); + } + + @Test + void testInitializeResult() throws Exception { + McpSchema.ServerCapabilities capabilities = McpSchema.ServerCapabilities.builder() + .logging() + .prompts(true) + .resources(true, true) + .tools(true) + .build(); + + McpSchema.Implementation serverInfo = new McpSchema.Implementation("test-server", "1.0.0"); + + McpSchema.InitializeResult result = new McpSchema.InitializeResult("2024-11-05", capabilities, serverInfo, + "Server initialized successfully"); + + String value = mapper.writeValueAsString(result); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"protocolVersion":"2024-11-05","capabilities":{"logging":{},"prompts":{"listChanged":true},"resources":{"subscribe":true,"listChanged":true},"tools":{"listChanged":true}},"serverInfo":{"name":"test-server","version":"1.0.0"},"instructions":"Server initialized successfully"}""")); + } + + // Resource Tests + + @Test + void testResource() throws Exception { + McpSchema.Annotations annotations = new McpSchema.Annotations( + Arrays.asList(McpSchema.Role.USER, McpSchema.Role.ASSISTANT), 0.8); + + McpSchema.Resource resource = new McpSchema.Resource("resource://test", "Test Resource", "A test resource", + "text/plain", annotations); + + String value = mapper.writeValueAsString(resource); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"uri":"resource://test","name":"Test Resource","description":"A test resource","mimeType":"text/plain","annotations":{"audience":["user","assistant"],"priority":0.8}}""")); + } + + @Test + void testResourceTemplate() throws Exception { + McpSchema.Annotations annotations = new McpSchema.Annotations(Arrays.asList(McpSchema.Role.USER), 0.5); + + McpSchema.ResourceTemplate template = new McpSchema.ResourceTemplate("resource://{param}/test", "Test Template", + "A test resource template", "text/plain", annotations); + + String value = mapper.writeValueAsString(template); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"uriTemplate":"resource://{param}/test","name":"Test Template","description":"A test resource template","mimeType":"text/plain","annotations":{"audience":["user"],"priority":0.5}}""")); + } + + @Test + void testListResourcesResult() throws Exception { + McpSchema.Resource resource1 = new McpSchema.Resource("resource://test1", "Test Resource 1", + "First test resource", "text/plain", null); + + McpSchema.Resource resource2 = new McpSchema.Resource("resource://test2", "Test Resource 2", + "Second test resource", "application/json", null); + + McpSchema.ListResourcesResult result = new McpSchema.ListResourcesResult(Arrays.asList(resource1, resource2), + "next-cursor"); + + String value = mapper.writeValueAsString(result); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"resources":[{"uri":"resource://test1","name":"Test Resource 1","description":"First test resource","mimeType":"text/plain"},{"uri":"resource://test2","name":"Test Resource 2","description":"Second test resource","mimeType":"application/json"}],"nextCursor":"next-cursor"}""")); + } + + @Test + void testListResourceTemplatesResult() throws Exception { + McpSchema.ResourceTemplate template1 = new McpSchema.ResourceTemplate("resource://{param}/test1", + "Test Template 1", "First test template", "text/plain", null); + + McpSchema.ResourceTemplate template2 = new McpSchema.ResourceTemplate("resource://{param}/test2", + "Test Template 2", "Second test template", "application/json", null); + + McpSchema.ListResourceTemplatesResult result = new McpSchema.ListResourceTemplatesResult( + Arrays.asList(template1, template2), "next-cursor"); + + String value = mapper.writeValueAsString(result); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"resourceTemplates":[{"uriTemplate":"resource://{param}/test1","name":"Test Template 1","description":"First test template","mimeType":"text/plain"},{"uriTemplate":"resource://{param}/test2","name":"Test Template 2","description":"Second test template","mimeType":"application/json"}],"nextCursor":"next-cursor"}""")); + } + + @Test + void testReadResourceRequest() throws Exception { + McpSchema.ReadResourceRequest request = new McpSchema.ReadResourceRequest("resource://test"); + + String value = mapper.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"uri":"resource://test"}""")); + } + + @Test + void testReadResourceResult() throws Exception { + McpSchema.TextResourceContents contents1 = new McpSchema.TextResourceContents("resource://test1", "text/plain", + "Sample text content"); + + McpSchema.BlobResourceContents contents2 = new McpSchema.BlobResourceContents("resource://test2", + "application/octet-stream", "base64encodedblob"); + + McpSchema.ReadResourceResult result = new McpSchema.ReadResourceResult(Arrays.asList(contents1, contents2)); + + String value = mapper.writeValueAsString(result); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"contents":[{"uri":"resource://test1","mimeType":"text/plain","text":"Sample text content"},{"uri":"resource://test2","mimeType":"application/octet-stream","blob":"base64encodedblob"}]}""")); + } + + // Prompt Tests + + @Test + void testPrompt() throws Exception { + McpSchema.PromptArgument arg1 = new McpSchema.PromptArgument("arg1", "First argument", true); + + McpSchema.PromptArgument arg2 = new McpSchema.PromptArgument("arg2", "Second argument", false); + + McpSchema.Prompt prompt = new McpSchema.Prompt("test-prompt", "A test prompt", Arrays.asList(arg1, arg2)); + + String value = mapper.writeValueAsString(prompt); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"name":"test-prompt","description":"A test prompt","arguments":[{"name":"arg1","description":"First argument","required":true},{"name":"arg2","description":"Second argument","required":false}]}""")); + } + + @Test + void testPromptMessage() throws Exception { + McpSchema.TextContent content = new McpSchema.TextContent("Hello, world!"); + + McpSchema.PromptMessage message = new McpSchema.PromptMessage(McpSchema.Role.USER, content); + + String value = mapper.writeValueAsString(message); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"role":"user","content":{"type":"text","text":"Hello, world!"}}""")); + } + + @Test + void testListPromptsResult() throws Exception { + McpSchema.PromptArgument arg = new McpSchema.PromptArgument("arg", "An argument", true); + + McpSchema.Prompt prompt1 = new McpSchema.Prompt("prompt1", "First prompt", Collections.singletonList(arg)); + + McpSchema.Prompt prompt2 = new McpSchema.Prompt("prompt2", "Second prompt", Collections.emptyList()); + + McpSchema.ListPromptsResult result = new McpSchema.ListPromptsResult(Arrays.asList(prompt1, prompt2), + "next-cursor"); + + String value = mapper.writeValueAsString(result); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"prompts":[{"name":"prompt1","description":"First prompt","arguments":[{"name":"arg","description":"An argument","required":true}]},{"name":"prompt2","description":"Second prompt","arguments":[]}],"nextCursor":"next-cursor"}""")); + } + + @Test + void testGetPromptRequest() throws Exception { + Map arguments = new HashMap<>(); + arguments.put("arg1", "value1"); + arguments.put("arg2", 42); + + McpSchema.GetPromptRequest request = new McpSchema.GetPromptRequest("test-prompt", arguments); + + assertThat(mapper.readValue(""" + {"name":"test-prompt","arguments":{"arg1":"value1","arg2":42}}""", McpSchema.GetPromptRequest.class)) + .isEqualTo(request); + } + + @Test + void testGetPromptResult() throws Exception { + McpSchema.TextContent content1 = new McpSchema.TextContent("System message"); + McpSchema.TextContent content2 = new McpSchema.TextContent("User message"); + + McpSchema.PromptMessage message1 = new McpSchema.PromptMessage(McpSchema.Role.ASSISTANT, content1); + + McpSchema.PromptMessage message2 = new McpSchema.PromptMessage(McpSchema.Role.USER, content2); + + McpSchema.GetPromptResult result = new McpSchema.GetPromptResult("A test prompt result", + Arrays.asList(message1, message2)); + + String value = mapper.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"description":"A test prompt result","messages":[{"role":"assistant","content":{"type":"text","text":"System message"}},{"role":"user","content":{"type":"text","text":"User message"}}]}""")); + } + + // Tool Tests + + @Test + void testTool() throws Exception { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "value": { + "type": "number" + } + }, + "required": ["name"] + } + """; + + McpSchema.Tool tool = new McpSchema.Tool("test-tool", "A test tool", schemaJson); + + String value = mapper.writeValueAsString(tool); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"name":"test-tool","description":"A test tool","inputSchema":{"type":"object","properties":{"name":{"type":"string"},"value":{"type":"number"}},"required":["name"]}}""")); + } + + @Test + void testCallToolRequest() throws Exception { + Map arguments = new HashMap<>(); + arguments.put("name", "test"); + arguments.put("value", 42); + + McpSchema.CallToolRequest request = new McpSchema.CallToolRequest("test-tool", arguments); + + String value = mapper.writeValueAsString(request); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"name":"test-tool","arguments":{"name":"test","value":42}}""")); + } + + @Test + void testCallToolResult() throws Exception { + McpSchema.TextContent content = new McpSchema.TextContent("Tool execution result"); + + McpSchema.CallToolResult result = new McpSchema.CallToolResult(Collections.singletonList(content), false); + + String value = mapper.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"content":[{"type":"text","text":"Tool execution result"}],"isError":false}""")); + } + + // Sampling Tests + + @Test + void testCreateMessageRequest() throws Exception { + McpSchema.TextContent content = new McpSchema.TextContent("User message"); + + McpSchema.SamplingMessage message = new McpSchema.SamplingMessage(McpSchema.Role.USER, content); + + McpSchema.ModelHint hint = new McpSchema.ModelHint("gpt-4"); + + McpSchema.ModelPreferences preferences = new McpSchema.ModelPreferences(Collections.singletonList(hint), 0.3, + 0.7, 0.9); + + Map metadata = new HashMap<>(); + metadata.put("session", "test-session"); + + McpSchema.CreateMessageRequest request = McpSchema.CreateMessageRequest.builder() + .messages(Collections.singletonList(message)) + .modelPreferences(preferences) + .systemPrompt("You are a helpful assistant") + .includeContext(McpSchema.CreateMessageRequest.ContextInclusionStrategy.THIS_SERVER) + .temperature(0.7) + .maxTokens(1000) + .stopSequences(Arrays.asList("STOP", "END")) + .metadata(metadata) + .build(); + + String value = mapper.writeValueAsString(request); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"messages":[{"role":"user","content":{"type":"text","text":"User message"}}],"modelPreferences":{"hints":[{"name":"gpt-4"}],"costPriority":0.3,"speedPriority":0.7,"intelligencePriority":0.9},"systemPrompt":"You are a helpful assistant","includeContext":"thisServer","temperature":0.7,"maxTokens":1000,"stopSequences":["STOP","END"],"metadata":{"session":"test-session"}}""")); + } + + @Test + void testCreateMessageResult() throws Exception { + McpSchema.TextContent content = new McpSchema.TextContent("Assistant response"); + + McpSchema.CreateMessageResult result = McpSchema.CreateMessageResult.builder() + .role(McpSchema.Role.ASSISTANT) + .content(content) + .model("gpt-4") + .stopReason(McpSchema.CreateMessageResult.StopReason.END_TURN) + .build(); + + String value = mapper.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"role":"assistant","content":{"type":"text","text":"Assistant response"},"model":"gpt-4","stopReason":"endTurn"}""")); + } + + // Roots Tests + + @Test + void testRoot() throws Exception { + McpSchema.Root root = new McpSchema.Root("file:///path/to/root", "Test Root"); + + String value = mapper.writeValueAsString(root); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"uri":"file:///path/to/root","name":"Test Root"}""")); + } + + @Test + void testListRootsResult() throws Exception { + McpSchema.Root root1 = new McpSchema.Root("file:///path/to/root1", "First Root"); + + McpSchema.Root root2 = new McpSchema.Root("file:///path/to/root2", "Second Root"); + + McpSchema.ListRootsResult result = new McpSchema.ListRootsResult(Arrays.asList(root1, root2)); + + String value = mapper.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"roots":[{"uri":"file:///path/to/root1","name":"First Root"},{"uri":"file:///path/to/root2","name":"Second Root"}]}""")); + + } + +} diff --git a/mcp/src/test/java/org/springframework/ai/mcp/MockMcpTransport.java b/mcp/src/test/java/org/springframework/ai/mcp/MockMcpTransport.java deleted file mode 100644 index 1622418e2..000000000 --- a/mcp/src/test/java/org/springframework/ai/mcp/MockMcpTransport.java +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mcp; - -import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Function; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.Sinks; -import reactor.core.scheduler.Schedulers; - -import org.springframework.ai.mcp.spec.ClientMcpTransport; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.JSONRPCNotification; -import org.springframework.ai.mcp.spec.McpSchema.JSONRPCRequest; -import org.springframework.ai.mcp.spec.ServerMcpTransport; - -/** - * A mock implementation of the {@link ClientMcpTransport} and {@link ServerMcpTransport} - * interfaces. - */ -public class MockMcpTransport implements ClientMcpTransport, ServerMcpTransport { - - private final AtomicInteger inboundMessageCount = new AtomicInteger(0); - - private final Sinks.Many outgoing = Sinks.many().multicast().onBackpressureBuffer(); - - private final Sinks.Many inbound = Sinks.many().unicast().onBackpressureBuffer(); - - private final Flux outboundView = outgoing.asFlux().cache(1); - - public void simulateIncomingMessage(McpSchema.JSONRPCMessage message) { - if (inbound.tryEmitNext(message).isFailure()) { - throw new RuntimeException("Failed to emit message " + message); - } - inboundMessageCount.incrementAndGet(); - } - - @Override - public Mono sendMessage(McpSchema.JSONRPCMessage message) { - if (outgoing.tryEmitNext(message).isFailure()) { - return Mono.error(new RuntimeException("Can't emit outgoing message " + message)); - } - return Mono.empty(); - } - - public McpSchema.JSONRPCRequest getLastSentMessageAsRequest() { - return (JSONRPCRequest) outboundView.blockFirst(); - } - - public McpSchema.JSONRPCNotification getLastSentMessageAsNotifiation() { - return (JSONRPCNotification) outboundView.blockFirst(); - } - - public McpSchema.JSONRPCMessage getLastSentMessage() { - return outboundView.blockFirst(); - } - - private volatile boolean connected = false; - - @Override - public Mono connect(Function, Mono> handler) { - if (connected) { - return Mono.error(new IllegalStateException("Already connected")); - } - connected = true; - return inbound.asFlux() - .publishOn(Schedulers.boundedElastic()) - .flatMap(message -> Mono.just(message).transform(handler)) - .doFinally(signal -> connected = false) - .then(); - } - - @Override - public Mono closeGracefully() { - return Mono.defer(() -> { - connected = false; - outgoing.tryEmitComplete(); - inbound.tryEmitComplete(); - // Wait for all subscribers to complete - return Mono.empty(); - }); - } - - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return new ObjectMapper().convertValue(data, typeRef); - } - -} diff --git a/mcp/src/test/java/org/springframework/ai/mcp/attic/ClientSessionTests2.java b/mcp/src/test/java/org/springframework/ai/mcp/attic/ClientSessionTests2.java deleted file mode 100644 index 47e6e7335..000000000 --- a/mcp/src/test/java/org/springframework/ai/mcp/attic/ClientSessionTests2.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mcp.attic; - -import java.time.Duration; - -import org.springframework.ai.mcp.client.McpClient; -import org.springframework.ai.mcp.client.McpSyncClient; -import org.springframework.ai.mcp.client.transport.ServerParameters; -import org.springframework.ai.mcp.client.transport.StdioClientTransport; -import org.springframework.ai.mcp.spec.McpSchema.ListResourcesResult; -import org.springframework.ai.mcp.spec.McpSchema.ListToolsResult; -import org.springframework.ai.mcp.spec.McpSchema.Resource; - -/** - * @author Christian Tzolov - * @since 1.0.0 - */ -public class ClientSessionTests2 { - - public static void main(String[] args) { - - var stdioParams = ServerParameters.builder("uv") - .args("--directory", "dir", "run", "mcp-server-sqlite", "--db-path", "~/test.db") - .build(); - - try (McpSyncClient clientSession = McpClient.using(new StdioClientTransport(stdioParams)) - .requestTimeout(Duration.ofSeconds(10)) - .sync()) { - - clientSession.initialize(); - - ListToolsResult tools = clientSession.listTools(null); - System.out.println("Tools: " + tools); - - clientSession.ping(); - - // CallToolRequest callToolRequest = new CallToolRequest("echo", - // Map.of("message", "Hello MCP Spring AI!")); - // CallToolResult callToolResult = - // clientSession.callTool(callToolRequest); - // System.out.println("Call Tool Result: " + callToolResult); - // // - // clientSession.sendRootsListChanged(); - - // Resources - ListResourcesResult resources = clientSession.listResources(null); - System.out.println("Resources Size: " + resources.resources().size()); - System.out.println("Resources: " + resources); - for (Resource resource : resources.resources()) { - System.out.println(clientSession.readResource(resource)); - - } - - // var resourceTemplate = clientSession.listResourceTemplates(null); - // System.out.println("Resource Templates: " + resourceTemplate); - - // stdioClient.awaitForExit(); - } - catch (Exception e) { - e.printStackTrace(); - } - } - -} diff --git a/mcp/src/test/java/org/springframework/ai/mcp/attic/ClientSessionTests3.java b/mcp/src/test/java/org/springframework/ai/mcp/attic/ClientSessionTests3.java deleted file mode 100644 index f147aa561..000000000 --- a/mcp/src/test/java/org/springframework/ai/mcp/attic/ClientSessionTests3.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mcp.attic; - -import java.time.Duration; - -import org.springframework.ai.mcp.client.McpClient; -import org.springframework.ai.mcp.client.McpSyncClient; -import org.springframework.ai.mcp.client.transport.ServerParameters; -import org.springframework.ai.mcp.client.transport.StdioClientTransport; -import org.springframework.ai.mcp.spec.McpSchema.ListResourcesResult; -import org.springframework.ai.mcp.spec.McpSchema.ListToolsResult; -import org.springframework.ai.mcp.spec.McpSchema.Resource; - -/** - * @author Christian Tzolov - * @since 1.0.0 - */ -public class ClientSessionTests3 { - - public static void main(String[] args) { - - var stdioParams = ServerParameters.builder("uv") - .args("--directory", "dir", "run", "mcp-server-sqlite", "--db-path", "~/test.db") - .build(); - - McpSyncClient clientSession = null; - try { - - clientSession = McpClient.using(new StdioClientTransport(stdioParams)) - .requestTimeout(Duration.ofSeconds(10)) - .sync(); - - clientSession.initialize(); - - ListToolsResult tools = clientSession.listTools(null); - System.out.println("Tools: " + tools); - - clientSession.ping(); - - // Resources - ListResourcesResult resources = clientSession.listResources(null); - System.out.println("Resources Size: " + resources.resources().size()); - System.out.println("Resources: " + resources); - for (Resource resource : resources.resources()) { - System.out.println(clientSession.readResource(resource)); - - } - } - catch (Exception e) { - e.printStackTrace(); - } - finally { - if (clientSession != null) { - clientSession.close(); - } - } - } - -} diff --git a/mcp/src/test/java/org/springframework/ai/mcp/attic/SSEClient.java b/mcp/src/test/java/org/springframework/ai/mcp/attic/SSEClient.java deleted file mode 100644 index 3b7599924..000000000 --- a/mcp/src/test/java/org/springframework/ai/mcp/attic/SSEClient.java +++ /dev/null @@ -1,183 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -package org.springframework.ai.mcp.attic; - -/** - * @author Christian Tzolov - * @since 1.0.0 - */ - -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.time.Duration; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.atomic.AtomicReference; -import java.util.regex.Pattern; -import java.util.stream.Stream; - -public class SSEClient { - - private final HttpClient httpClient; - - private static final Pattern EVENT_DATA_PATTERN = Pattern.compile("^data:(.+)$", Pattern.MULTILINE); - - private static final Pattern EVENT_ID_PATTERN = Pattern.compile("^id:(.+)$", Pattern.MULTILINE); - - private static final Pattern EVENT_TYPE_PATTERN = Pattern.compile("^event:(.+)$", Pattern.MULTILINE); - - public SSEClient() { - this.httpClient = HttpClient.newBuilder().connectTimeout(Duration.ofSeconds(10)).build(); - } - - public void subscribe(String url, SSEEventHandler eventHandler) { - HttpRequest request = HttpRequest.newBuilder() - .uri(URI.create(url)) - .header("Accept", "text/event-stream") - .header("Cache-Control", "no-cache") - .GET() - .build(); - - CompletableFuture future = httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofLines()) - .thenAccept(response -> { - if (response.statusCode() == 200) { - processSSEStream(response, eventHandler); - } - else { - throw new RuntimeException("Failed to connect to SSE stream: " + response.statusCode()); - } - }); - - // Handle errors - future.exceptionally(throwable -> { - eventHandler.onError(throwable); - return null; - }); - } - - private void processSSEStream(HttpResponse> response, SSEEventHandler eventHandler) { - StringBuilder eventBuilder = new StringBuilder(); - AtomicReference currentEventId = new AtomicReference<>(); - AtomicReference currentEventType = new AtomicReference<>("message"); // default - // event - // type - - response.body().forEach(line -> { - if (line.isEmpty()) { - // Empty line means end of event - if (eventBuilder.length() > 0) { - String eventData = eventBuilder.toString(); - SSEEvent event = parseEvent(eventData, currentEventId.get(), currentEventType.get()); - eventHandler.onEvent(event); - eventBuilder.setLength(0); - } - } - else { - if (line.startsWith("data:")) { - var matcher = EVENT_DATA_PATTERN.matcher(line); - if (matcher.find()) { - eventBuilder.append(matcher.group(1).trim()).append("\n"); - } - } - else if (line.startsWith("id:")) { - var matcher = EVENT_ID_PATTERN.matcher(line); - if (matcher.find()) { - currentEventId.set(matcher.group(1).trim()); - } - } - else if (line.startsWith("event:")) { - var matcher = EVENT_TYPE_PATTERN.matcher(line); - if (matcher.find()) { - currentEventType.set(matcher.group(1).trim()); - } - } - } - }); - } - - private SSEEvent parseEvent(String eventData, String eventId, String eventType) { - return new SSEEvent(eventId, eventType, eventData.trim()); - } - - // Event handler interface - public interface SSEEventHandler { - - void onEvent(SSEEvent event); - - void onError(Throwable error); - - } - - // SSE Event class - public static class SSEEvent { - - private final String id; - - private final String type; - - private final String data; - - public SSEEvent(String id, String type, String data) { - this.id = id; - this.type = type; - this.data = data; - } - - public String getId() { - return id; - } - - public String getType() { - return type; - } - - public String getData() { - return data; - } - - @Override - public String toString() { - return "SSEEvent{" + "id='" + id + '\'' + ", type='" + type + '\'' + ", data='" + data + '\'' + '}'; - } - - } - - public static void main(String[] args) { - SSEClient sseClient = new SSEClient(); - - var eventHandler = new SSEEventHandler() { - @Override - public void onEvent(SSEEvent event) { - System.out.println("Received event: " + event); - } - - @Override - public void onError(Throwable error) { - System.out.println("Error: " + error); - } - }; - sseClient.subscribe("http://localhost:8080/sse", eventHandler); - - try { - Thread.sleep(100000); - } - catch (InterruptedException e) { - e.printStackTrace(); - } - } - -} \ No newline at end of file diff --git a/mcp/src/test/java/org/springframework/ai/mcp/attic/ServerParametersParser.java b/mcp/src/test/java/org/springframework/ai/mcp/attic/ServerParametersParser.java deleted file mode 100644 index e093723fd..000000000 --- a/mcp/src/test/java/org/springframework/ai/mcp/attic/ServerParametersParser.java +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mcp.attic; - -import java.util.List; -import java.util.Map; - -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.databind.ObjectMapper; - -import org.springframework.ai.mcp.attic.ServerParametersParser.McpServerConfigurations.McpServerConfiguration; - -/** - * @author Christian Tzolov - * @author Dariusz Jędrzejczyk - */ - -public class ServerParametersParser { - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - public record McpServerConfigurations(// @formatter:off - @JsonProperty("mcpServers") Map mcpServers) { - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - public record McpServerConfiguration( - @JsonProperty("command") String command, - @JsonProperty("args") List args, - @JsonProperty("env") Map env) { - } - } // @formatter:on - - public static void main(String[] args) throws Exception { - ObjectMapper objectMapper = new ObjectMapper(); - - // Assuming the JSON is in a file named "servers.json" - String jsonInput = """ - { - "mcpServers": { - "filesystem": { - "command": "npx", - "args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/allowed/files"] - }, - "git": { - "command": "uvx", - "args": ["mcp-server-git", "--repository", "path/to/git/repo"] - }, - "github": { - "command": "npx", - "args": ["-y", "@modelcontextprotocol/server-github"], - "env": { - "GITHUB_PERSONAL_ACCESS_TOKEN": "" - } - }, - "postgres": { - "command": "npx", - "args": ["-y", "@modelcontextprotocol/server-postgres", "postgresql://localhost/mydb"] - } - } - } - """; - - // Deserialize JSON into McpServers - McpServerConfigurations servers = objectMapper.readValue(jsonInput, McpServerConfigurations.class); - - // Access individual servers - for (Map.Entry entry : servers.mcpServers().entrySet()) { - System.out.println("Key: " + entry.getKey()); - System.out.println("Command: " + entry.getValue().command()); - System.out.println("Arguments: " + String.join(", ", entry.getValue().args())); - System.out.println("Environment Variables: " + entry.getValue().env()); - System.out.println("-----------"); - } - } - -} diff --git a/mcp/src/test/java/org/springframework/ai/mcp/attic/_DemoServer.java_ b/mcp/src/test/java/org/springframework/ai/mcp/attic/_DemoServer.java_ deleted file mode 100644 index df1fc0b61..000000000 --- a/mcp/src/test/java/org/springframework/ai/mcp/attic/_DemoServer.java_ +++ /dev/null @@ -1,61 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -package org.springframework.ai.mcp.attic; - -import java.util.List; -import java.util.Map; - -import com.fasterxml.jackson.databind.ObjectMapper; -import reactor.netty.http.server.HttpServer; - -import org.springframework.ai.mcp.server.McpServer; -import org.springframework.ai.mcp.server.transport.SseServerTransport; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.CallToolResult; -import org.springframework.http.server.reactive.HttpHandler; -import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; -import org.springframework.web.reactive.function.server.RouterFunctions; - -/** - * - * https://docs.spring.io/spring-framework/reference/web/webflux-functional.html#webflux-fn-running - * https://docs.spring.io/spring-framework/reference/web/webflux/reactive-spring.html#webflux-httphandler - * - * @author Christian Tzolov - */ - -public class DemoServer { - - public static void main(String[] args) { - SseServerTransport transport = new SseServerTransport(new ObjectMapper(), "/mcp/message"); - - var mcpServer = McpServer.using(transport) - .serverInfo("Weather Forecast", "1.0.0") - .tool(new McpSchema.Tool("weather", "Weather forecast tool by location", "{ \"type\": \"object\" }"), - (arguments) -> { - String city = (String) arguments.get("city"); - return new CallToolResult(List.of(), false); - }) - .async(); - - HttpHandler httpHandler = RouterFunctions.toHttpHandler(transport.getRouterFunction()); - ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); - HttpServer httpServer = HttpServer.create().port(8080).handle(adapter); - httpServer.bindNow().onDispose().block(); - - } - -} diff --git a/mcp/src/test/java/org/springframework/ai/mcp/attic/_MainSSE.java_ b/mcp/src/test/java/org/springframework/ai/mcp/attic/_MainSSE.java_ deleted file mode 100644 index e4c3004a2..000000000 --- a/mcp/src/test/java/org/springframework/ai/mcp/attic/_MainSSE.java_ +++ /dev/null @@ -1,81 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -package org.springframework.ai.mcp.attic; - -import java.time.Duration; -import java.util.Map; - -import org.springframework.ai.mcp.client.McpClient; -import org.springframework.ai.mcp.client.McpSyncClient; -import org.springframework.ai.mcp.client.transport.SseClientTransport; -import org.springframework.ai.mcp.spec.McpSchema.CallToolRequest; -import org.springframework.ai.mcp.spec.McpSchema.CallToolResult; -import org.springframework.ai.mcp.spec.McpSchema.ListPromptsResult; -import org.springframework.ai.mcp.spec.McpSchema.ListResourcesResult; -import org.springframework.ai.mcp.spec.McpSchema.ListToolsResult; -import org.springframework.ai.mcp.spec.McpSchema.Resource; -import org.springframework.web.reactive.function.client.WebClient; - -/** - * @author Christian Tzolov - * @since 1.0.0 - */ - -public class MainSSE { - - public static void main(String[] args) { - - try (McpSyncClient mcpClient = McpClient - .using(new SseClientTransport(WebClient.builder().baseUrl("http://localhost:3001"))) - .requestTimeout(Duration.ofSeconds(1000)) - .sync()) { - - mcpClient.initialize(); - - ListToolsResult tools = mcpClient.listTools(); - System.out.println("Tools: " + tools); - - mcpClient.ping(); - - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", "Hello MCP Spring AI!")); - CallToolResult callToolResult = mcpClient.callTool(callToolRequest); - System.out.println("Call Tool Result: " + callToolResult); - - // mcpClient.sendRootsListChanged(); - - // Resources - ListResourcesResult resources = mcpClient.listResources(); - System.out.println("Resources Size: " + resources.resources().size()); - System.out.println("Resources: " + resources); - for (Resource resource : resources.resources()) { - System.out.println(mcpClient.readResource(resource)); - } - - var resourceTemplate = mcpClient.listResourceTemplates(); - System.out.println("Resource Templates: " + resourceTemplate); - - ListPromptsResult prompts = mcpClient.listPrompts(); - for (var prompt : prompts.prompts()) { - System.out.println("Prompt: " + prompt); - } - - } - catch (Exception e) { - e.printStackTrace(); - } - } - -} diff --git a/mcp/src/test/java/org/springframework/ai/mcp/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/org/springframework/ai/mcp/client/AbstractMcpAsyncClientTests.java deleted file mode 100644 index 0d7fd0bdf..000000000 --- a/mcp/src/test/java/org/springframework/ai/mcp/client/AbstractMcpAsyncClientTests.java +++ /dev/null @@ -1,376 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mcp.client; - -import java.time.Duration; -import java.util.Map; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Function; - -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -import org.springframework.ai.mcp.spec.ClientMcpTransport; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.CallToolRequest; -import org.springframework.ai.mcp.spec.McpSchema.ClientCapabilities; -import org.springframework.ai.mcp.spec.McpSchema.CreateMessageRequest; -import org.springframework.ai.mcp.spec.McpSchema.CreateMessageResult; -import org.springframework.ai.mcp.spec.McpSchema.GetPromptRequest; -import org.springframework.ai.mcp.spec.McpSchema.Prompt; -import org.springframework.ai.mcp.spec.McpSchema.Resource; -import org.springframework.ai.mcp.spec.McpSchema.Root; -import org.springframework.ai.mcp.spec.McpSchema.SubscribeRequest; -import org.springframework.ai.mcp.spec.McpSchema.Tool; -import org.springframework.ai.mcp.spec.McpSchema.UnsubscribeRequest; -import org.springframework.ai.mcp.spec.McpTransport; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Test suite for the {@link McpAsyncClient} that can be used with different - * {@link McpTransport} implementations. - * - * @author Christian Tzolov - * @author Dariusz Jędrzejczyk - */ -// KEEP IN SYNC with the class in mcp-test module -public abstract class AbstractMcpAsyncClientTests { - - private McpAsyncClient mcpAsyncClient; - - protected ClientMcpTransport mcpTransport; - - private static final Duration TIMEOUT = Duration.ofSeconds(20); - - private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!"; - - abstract protected ClientMcpTransport createMcpTransport(); - - protected void onStart() { - } - - protected void onClose() { - } - - @BeforeEach - void setUp() { - onStart(); - this.mcpTransport = createMcpTransport(); - - assertThatCode(() -> { - mcpAsyncClient = McpClient.async(mcpTransport) - .requestTimeout(TIMEOUT) - .capabilities(ClientCapabilities.builder().roots(true).build()) - .build(); - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); - } - - @AfterEach - void tearDown() { - if (mcpAsyncClient != null) { - assertThatCode(() -> mcpAsyncClient.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - } - onClose(); - } - - @Test - void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport must not be null"); - - assertThatThrownBy(() -> McpClient.sync(mcpTransport).requestTimeout(null).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Request timeout must not be null"); - } - - @Test - void testListTools() { - StepVerifier.create(mcpAsyncClient.listTools(null)).consumeNextWith(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); - - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); - }).verifyComplete(); - } - - @Test - void testPing() { - assertThatCode(() -> mcpAsyncClient.ping().block()).doesNotThrowAnyException(); - } - - @Test - void testCallTool() { - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - - StepVerifier.create(mcpAsyncClient.callTool(callToolRequest)).consumeNextWith(callToolResult -> { - assertThat(callToolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).isNotNull(); - assertThat(result.isError()).isNull(); - }); - }).verifyComplete(); - } - - @Test - void testCallToolWithInvalidTool() { - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", ECHO_TEST_MESSAGE)); - - assertThatThrownBy(() -> mcpAsyncClient.callTool(invalidRequest).block()).isInstanceOf(Exception.class); - } - - @Test - void testListResources() { - StepVerifier.create(mcpAsyncClient.listResources(null)).consumeNextWith(resources -> { - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } - }); - }).verifyComplete(); - } - - @Test - void testMcpAsyncClientState() { - assertThat(mcpAsyncClient).isNotNull(); - } - - @Test - void testListPrompts() { - StepVerifier.create(mcpAsyncClient.listPrompts(null)).consumeNextWith(prompts -> { - assertThat(prompts).isNotNull().satisfies(result -> { - assertThat(result.prompts()).isNotNull(); - - if (!result.prompts().isEmpty()) { - Prompt firstPrompt = result.prompts().get(0); - assertThat(firstPrompt.name()).isNotNull(); - assertThat(firstPrompt.description()).isNotNull(); - } - }); - }).verifyComplete(); - } - - @Test - void testGetPrompt() { - StepVerifier.create(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of()))) - .consumeNextWith(prompt -> { - assertThat(prompt).isNotNull().satisfies(result -> { - assertThat(result.messages()).isNotEmpty(); - assertThat(result.messages()).hasSize(1); - }); - }) - .verifyComplete(); - } - - @Test - void testRootsListChanged() { - assertThatCode(() -> mcpAsyncClient.rootsListChangedNotification().block()).doesNotThrowAnyException(); - } - - @Test - void testInitializeWithRootsListProviders() { - var transport = createMcpTransport(); - - var client = McpClient.async(transport) - .requestTimeout(TIMEOUT) - .roots(new Root("file:///test/path", "test-root")) - .build(); - - assertThatCode(() -> client.initialize().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - - assertThatCode(() -> client.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddRoot() { - Root newRoot = new Root("file:///new/test/path", "new-test-root"); - assertThatCode(() -> mcpAsyncClient.addRoot(newRoot).block()).doesNotThrowAnyException(); - } - - @Test - void testAddRootWithNullValue() { - assertThatThrownBy(() -> mcpAsyncClient.addRoot(null).block()).hasMessageContaining("Root must not be null"); - } - - @Test - void testRemoveRoot() { - Root root = new Root("file:///test/path/to/remove", "root-to-remove"); - assertThatCode(() -> { - mcpAsyncClient.addRoot(root).block(); - mcpAsyncClient.removeRoot(root.uri()).block(); - }).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonExistentRoot() { - assertThatThrownBy(() -> mcpAsyncClient.removeRoot("nonexistent-uri").block()) - .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); - } - - @Test - @Disabled - void testReadResource() { - StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - StepVerifier.create(mcpAsyncClient.readResource(firstResource)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.contents()).isNotNull(); - }).verifyComplete(); - } - }).verifyComplete(); - } - - @Test - void testListResourceTemplates() { - StepVerifier.create(mcpAsyncClient.listResourceTemplates()).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); - }).verifyComplete(); - } - - // @Test - void testResourceSubscription() { - StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - - // Test subscribe - StepVerifier.create(mcpAsyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) - .verifyComplete(); - - // Test unsubscribe - StepVerifier.create(mcpAsyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) - .verifyComplete(); - } - }).verifyComplete(); - } - - @Test - void testNotificationHandlers() { - AtomicBoolean toolsNotificationReceived = new AtomicBoolean(false); - AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); - AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); - - var transport = createMcpTransport(); - var client = McpClient.async(transport) - .requestTimeout(TIMEOUT) - .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) - .resourcesChangeConsumer(resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) - .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))) - .build(); - - assertThatCode(() -> { - client.initialize().block(); - // Trigger notifications - client.sendResourcesListChanged().block(); - client.promptListChangedNotification().block(); - client.closeGracefully().block(); - }).doesNotThrowAnyException(); - } - - @Test - void testInitializeWithSamplingCapability() { - var transport = createMcpTransport(); - - var capabilities = ClientCapabilities.builder().sampling().build(); - - var client = McpClient.async(transport) - .requestTimeout(TIMEOUT) - .capabilities(capabilities) - .sampling(request -> Mono.just(CreateMessageResult.builder().message("test").model("test-model").build())) - .build(); - - assertThatCode(() -> { - client.initialize().block(Duration.ofSeconds(10)); - client.closeGracefully().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); - } - - @Test - void testInitializeWithAllCapabilities() { - var transport = createMcpTransport(); - - var capabilities = ClientCapabilities.builder() - .experimental(Map.of("feature", "test")) - .roots(true) - .sampling() - .build(); - - Function> samplingHandler = request -> Mono - .just(CreateMessageResult.builder().message("test").model("test-model").build()); - var client = McpClient.async(transport) - .requestTimeout(TIMEOUT) - .capabilities(capabilities) - .sampling(samplingHandler) - .build(); - - assertThatCode(() -> { - var result = client.initialize().block(Duration.ofSeconds(10)); - assertThat(result).isNotNull(); - assertThat(result.capabilities()).isNotNull(); - client.closeGracefully().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @Test - void testLoggingLevels() { - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - StepVerifier.create(mcpAsyncClient.setLoggingLevel(level)).verifyComplete(); - } - } - - @Test - void testLoggingConsumer() { - AtomicBoolean logReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - - var client = McpClient.async(transport) - .requestTimeout(TIMEOUT) - .loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))) - .build(); - - assertThatCode(() -> { - client.initialize().block(Duration.ofSeconds(10)); - client.closeGracefully().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); - } - - @Test - void testLoggingWithNullNotification() { - assertThatThrownBy(() -> mcpAsyncClient.setLoggingLevel(null).block()) - .hasMessageContaining("Logging level must not be null"); - } - -} diff --git a/mcp/src/test/java/org/springframework/ai/mcp/client/AbstractMcpSyncClientTests.java b/mcp/src/test/java/org/springframework/ai/mcp/client/AbstractMcpSyncClientTests.java deleted file mode 100644 index c046b9eb2..000000000 --- a/mcp/src/test/java/org/springframework/ai/mcp/client/AbstractMcpSyncClientTests.java +++ /dev/null @@ -1,316 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mcp.client; - -import java.time.Duration; -import java.util.Map; -import java.util.concurrent.atomic.AtomicBoolean; - -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import org.springframework.ai.mcp.spec.ClientMcpTransport; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spec.McpSchema.CallToolRequest; -import org.springframework.ai.mcp.spec.McpSchema.CallToolResult; -import org.springframework.ai.mcp.spec.McpSchema.ClientCapabilities; -import org.springframework.ai.mcp.spec.McpSchema.ListResourceTemplatesResult; -import org.springframework.ai.mcp.spec.McpSchema.ListResourcesResult; -import org.springframework.ai.mcp.spec.McpSchema.ListToolsResult; -import org.springframework.ai.mcp.spec.McpSchema.ReadResourceResult; -import org.springframework.ai.mcp.spec.McpSchema.Resource; -import org.springframework.ai.mcp.spec.McpSchema.Root; -import org.springframework.ai.mcp.spec.McpSchema.SubscribeRequest; -import org.springframework.ai.mcp.spec.McpSchema.TextContent; -import org.springframework.ai.mcp.spec.McpSchema.Tool; -import org.springframework.ai.mcp.spec.McpSchema.UnsubscribeRequest; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Unit tests for MCP Client Session functionality. - * - * @author Christian Tzolov - * @author Dariusz Jędrzejczyk - */ -// KEEP IN SYNC with the class in mcp-test module -public abstract class AbstractMcpSyncClientTests { - - private McpSyncClient mcpSyncClient; - - private static final Duration TIMEOUT = Duration.ofSeconds(10); - - private static final String TEST_MESSAGE = "Hello MCP Spring AI!"; - - protected ClientMcpTransport mcpTransport; - - abstract protected ClientMcpTransport createMcpTransport(); - - abstract protected void onStart(); - - abstract protected void onClose(); - - @BeforeEach - void setUp() { - onStart(); - this.mcpTransport = createMcpTransport(); - - assertThatCode(() -> { - mcpSyncClient = McpClient.sync(mcpTransport) - .requestTimeout(TIMEOUT) - .capabilities(ClientCapabilities.builder().roots(true).build()) - .build(); - mcpSyncClient.initialize(); - }).doesNotThrowAnyException(); - } - - @AfterEach - void tearDown() { - if (mcpSyncClient != null) { - assertThatCode(() -> mcpSyncClient.close()).doesNotThrowAnyException(); - } - onClose(); - } - - @Test - void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport must not be null"); - - assertThatThrownBy(() -> McpClient.sync(mcpTransport).requestTimeout(null).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Request timeout must not be null"); - } - - @Test - void testListTools() { - ListToolsResult tools = mcpSyncClient.listTools(null); - - assertThat(tools).isNotNull().satisfies(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); - - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); - }); - } - - @Test - void testCallTools() { - CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))); - - assertThat(toolResult).isNotNull().satisfies(result -> { - - assertThat(result.content()).hasSize(1); - - TextContent content = (TextContent) result.content().get(0); - - assertThat(content).isNotNull(); - assertThat(content.text()).isNotNull(); - assertThat(content.text()).contains("7"); - }); - } - - @Test - void testPing() { - assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException(); - } - - @Test - void testCallTool() { - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - - CallToolResult callToolResult = mcpSyncClient.callTool(callToolRequest); - - assertThat(callToolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).isNotNull(); - assertThat(result.isError()).isNull(); - }); - } - - @Test - void testCallToolWithInvalidTool() { - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE)); - - assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class); - } - - @Test - void testRootsListChanged() { - assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException(); - } - - @Test - void testListResources() { - ListResourcesResult resources = mcpSyncClient.listResources(null); - - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } - }); - } - - @Test - void testClientSessionState() { - assertThat(mcpSyncClient).isNotNull(); - } - - @Test - void testInitializeWithRootsListProviders() { - var transport = createMcpTransport(); - - var client = McpClient.sync(transport) - .requestTimeout(TIMEOUT) - .roots(new Root("file:///test/path", "test-root")) - .build(); - - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); - } - - @Test - void testAddRoot() { - Root newRoot = new Root("file:///new/test/path", "new-test-root"); - assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException(); - } - - @Test - void testAddRootWithNullValue() { - assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null"); - } - - @Test - void testRemoveRoot() { - Root root = new Root("file:///test/path/to/remove", "root-to-remove"); - assertThatCode(() -> { - mcpSyncClient.addRoot(root); - mcpSyncClient.removeRoot(root.uri()); - }).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonExistentRoot() { - assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri")) - .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); - } - - @Test - void testReadResource() { - ListResourcesResult resources = mcpSyncClient.listResources(null); - - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - ReadResourceResult result = mcpSyncClient.readResource(firstResource); - - assertThat(result).isNotNull(); - assertThat(result.contents()).isNotNull(); - } - } - - @Test - void testListResourceTemplates() { - ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(null); - - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); - } - - // @Test - void testResourceSubscription() { - ListResourcesResult resources = mcpSyncClient.listResources(null); - - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - - // Test subscribe - assertThatCode(() -> mcpSyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) - .doesNotThrowAnyException(); - - // Test unsubscribe - assertThatCode(() -> mcpSyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) - .doesNotThrowAnyException(); - } - } - - @Test - void testNotificationHandlers() { - AtomicBoolean toolsNotificationReceived = new AtomicBoolean(false); - AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); - AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); - - var transport = createMcpTransport(); - var client = McpClient.sync(transport) - .requestTimeout(TIMEOUT) - .toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) - .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) - .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)) - .build(); - - assertThatCode(() -> { - client.initialize(); - // Trigger notifications - client.sendResourcesListChanged(); - client.promptListChangedNotification(); - client.close(); - }).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @Test - void testLoggingLevels() { - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - assertThatCode(() -> mcpSyncClient.setLoggingLevel(level)).doesNotThrowAnyException(); - } - } - - @Test - void testLoggingConsumer() { - AtomicBoolean logReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - - var client = McpClient.sync(transport) - .requestTimeout(TIMEOUT) - .loggingConsumer(notification -> logReceived.set(true)) - .build(); - - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); - } - - @Test - void testLoggingWithNullNotification() { - assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null)) - .hasMessageContaining("Logging level must not be null"); - } - -} diff --git a/mcp/src/test/java/org/springframework/ai/mcp/client/StdioMcpAsyncClientTests.java b/mcp/src/test/java/org/springframework/ai/mcp/client/StdioMcpAsyncClientTests.java deleted file mode 100644 index 94c9b41ca..000000000 --- a/mcp/src/test/java/org/springframework/ai/mcp/client/StdioMcpAsyncClientTests.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mcp.client; - -import org.junit.jupiter.api.Timeout; - -import org.springframework.ai.mcp.client.transport.ServerParameters; -import org.springframework.ai.mcp.client.transport.StdioClientTransport; -import org.springframework.ai.mcp.spec.ClientMcpTransport; - -/** - * Tests for the {@link McpAsyncClient} with {@link StdioClientTransport}. - * - * @author Christian Tzolov - * @author Dariusz Jędrzejczyk - */ -@Timeout(15) // Giving extra time beyond the client timeout -class StdioMcpAsyncClientTests extends AbstractMcpAsyncClientTests { - - @Override - protected ClientMcpTransport createMcpTransport() { - ServerParameters stdioParams = ServerParameters.builder("npx") - .args("-y", "@modelcontextprotocol/server-everything", "dir") - .build(); - return new StdioClientTransport(stdioParams); - } - -} diff --git a/mcp/src/test/java/org/springframework/ai/mcp/client/StdioMcpSyncClientTests.java b/mcp/src/test/java/org/springframework/ai/mcp/client/StdioMcpSyncClientTests.java deleted file mode 100644 index 962bc2ac6..000000000 --- a/mcp/src/test/java/org/springframework/ai/mcp/client/StdioMcpSyncClientTests.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mcp.client; - -import java.util.concurrent.atomic.AtomicReference; - -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.Timeout; - -import org.springframework.ai.mcp.client.transport.ServerParameters; -import org.springframework.ai.mcp.client.transport.StdioClientTransport; -import org.springframework.ai.mcp.spec.ClientMcpTransport; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * Tests for the {@link McpSyncClient} with {@link StdioClientTransport}. - * - * @author Christian Tzolov - * @author Dariusz Jędrzejczyk - */ -@Timeout(15) // Giving extra time beyond the client timeout -class StdioMcpSyncClientTests extends AbstractMcpSyncClientTests { - - @Override - protected ClientMcpTransport createMcpTransport() { - ServerParameters stdioParams = ServerParameters.builder("npx") - .args("-y", "@modelcontextprotocol/server-everything", "dir") - .build(); - - return new StdioClientTransport(stdioParams); - } - - @Test - void customErrorHandlerShouldReceiveErrors() { - AtomicReference receivedError = new AtomicReference<>(); - - ((StdioClientTransport) mcpTransport).setErrorHandler(error -> receivedError.set(error)); - - String errorMessage = "Test error"; - ((StdioClientTransport) mcpTransport).getErrorSink().tryEmitNext(errorMessage); - - assertThat(receivedError.get()).isNotNull().isEqualTo(errorMessage); - } - - @Override - protected void onStart() { - } - - @Override - protected void onClose() { - } - -} diff --git a/mcp/src/test/java/org/springframework/ai/mcp/server/ServletSseMcpAsyncServerTests.java b/mcp/src/test/java/org/springframework/ai/mcp/server/ServletSseMcpAsyncServerTests.java deleted file mode 100644 index 584938beb..000000000 --- a/mcp/src/test/java/org/springframework/ai/mcp/server/ServletSseMcpAsyncServerTests.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mcp.server; - -import com.fasterxml.jackson.databind.ObjectMapper; -import org.junit.jupiter.api.Timeout; - -import org.springframework.ai.mcp.server.transport.HttpServletSseServerTransport; -import org.springframework.ai.mcp.spec.ServerMcpTransport; - -/** - * Tests for {@link McpAsyncServer} using {@link HttpServletSseServerTransport}. - * - * @author Christian Tzolov - */ -@Timeout(15) // Giving extra time beyond the client timeout -class ServletSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { - - @Override - protected ServerMcpTransport createMcpTransport() { - return new HttpServletSseServerTransport(new ObjectMapper(), "/mcp/message"); - } - -} diff --git a/mcp/src/test/java/org/springframework/ai/mcp/server/ServletSseMcpSyncServerTests.java b/mcp/src/test/java/org/springframework/ai/mcp/server/ServletSseMcpSyncServerTests.java deleted file mode 100644 index 2d3390de4..000000000 --- a/mcp/src/test/java/org/springframework/ai/mcp/server/ServletSseMcpSyncServerTests.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mcp.server; - -import com.fasterxml.jackson.databind.ObjectMapper; -import org.junit.jupiter.api.Timeout; - -import org.springframework.ai.mcp.server.transport.HttpServletSseServerTransport; -import org.springframework.ai.mcp.spec.ServerMcpTransport; - -/** - * Tests for {@link McpSyncServer} using {@link HttpServletSseServerTransport}. - * - * @author Christian Tzolov - */ -@Timeout(15) // Giving extra time beyond the client timeout -class ServletSseMcpSyncServerTests extends AbstractMcpSyncServerTests { - - @Override - protected ServerMcpTransport createMcpTransport() { - return new HttpServletSseServerTransport(new ObjectMapper(), "/mcp/message"); - } - -} diff --git a/mcp/src/test/java/org/springframework/ai/mcp/server/StdioMcpAsyncServerTests.java b/mcp/src/test/java/org/springframework/ai/mcp/server/StdioMcpAsyncServerTests.java deleted file mode 100644 index 9f15658b6..000000000 --- a/mcp/src/test/java/org/springframework/ai/mcp/server/StdioMcpAsyncServerTests.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mcp.server; - -import org.junit.jupiter.api.Timeout; - -import org.springframework.ai.mcp.server.transport.StdioServerTransport; -import org.springframework.ai.mcp.spec.ServerMcpTransport; - -/** - * Tests for {@link McpAsyncServer} using {@link StdioServerTransport}. - * - * @author Christian Tzolov - */ -@Timeout(15) // Giving extra time beyond the client timeout -class StdioMcpAsyncServerTests extends AbstractMcpAsyncServerTests { - - @Override - protected ServerMcpTransport createMcpTransport() { - return new StdioServerTransport(); - } - -} diff --git a/mcp/src/test/java/org/springframework/ai/mcp/server/StdioMcpSyncServerTests.java b/mcp/src/test/java/org/springframework/ai/mcp/server/StdioMcpSyncServerTests.java deleted file mode 100644 index 6bcc4bedf..000000000 --- a/mcp/src/test/java/org/springframework/ai/mcp/server/StdioMcpSyncServerTests.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mcp.server; - -import org.junit.jupiter.api.Timeout; - -import org.springframework.ai.mcp.server.transport.StdioServerTransport; -import org.springframework.ai.mcp.spec.ServerMcpTransport; - -/** - * Tests for {@link McpSyncServer} using {@link StdioServerTransport}. - * - * @author Christian Tzolov - */ -@Timeout(15) // Giving extra time beyond the client timeout -class StdioMcpSyncServerTests extends AbstractMcpSyncServerTests { - - @Override - protected ServerMcpTransport createMcpTransport() { - return new StdioServerTransport(); - } - -} diff --git a/mcp/src/test/java/org/springframework/ai/mcp/server/transport/BlockingInputStream.java b/mcp/src/test/java/org/springframework/ai/mcp/server/transport/BlockingInputStream.java deleted file mode 100644 index 96690dbd8..000000000 --- a/mcp/src/test/java/org/springframework/ai/mcp/server/transport/BlockingInputStream.java +++ /dev/null @@ -1,81 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -package org.springframework.ai.mcp.server.transport; - -import java.io.IOException; -import java.io.InputStream; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.LinkedBlockingQueue; - -public class BlockingInputStream extends InputStream { - - private final BlockingQueue queue = new LinkedBlockingQueue<>(); - - private volatile boolean completed = false; - - private volatile boolean closed = false; - - @Override - public int read() throws IOException { - if (closed) { - throw new IOException("Stream is closed"); - } - - try { - Integer value = queue.poll(); - if (value == null) { - if (completed) { - return -1; - } - value = queue.take(); // Blocks until data is available - if (value == null && completed) { - return -1; - } - } - return value; - } - catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new IOException("Read interrupted", e); - } - } - - public void write(int b) { - if (!closed && !completed) { - queue.offer(b); - } - } - - public void write(byte[] data) { - if (!closed && !completed) { - for (byte b : data) { - queue.offer((int) b & 0xFF); - } - } - } - - public void complete() { - this.completed = true; - } - - @Override - public void close() { - this.closed = true; - this.completed = true; - this.queue.clear(); - } - -} \ No newline at end of file diff --git a/migration-0.8.0.md b/migration-0.8.0.md new file mode 100644 index 000000000..3ba29a10b --- /dev/null +++ b/migration-0.8.0.md @@ -0,0 +1,328 @@ +# MCP Java SDK Migration Guide: 0.7.0 to 0.8.0 + +This document outlines the breaking changes and provides guidance on how to migrate your code from version 0.7.0 to 0.8.0. + +The 0.8.0 refactoring introduces a session-based architecture for server-side MCP implementations. +It improves the SDK's ability to handle multiple concurrent client connections and provides an API better aligned with the MCP specification. +The main changes include: + +1. Introduction of a session-based architecture +2. New transport provider abstraction +3. Exchange objects for client interaction +4. Renamed and reorganized interfaces +5. Updated handler signatures + +## Breaking Changes + +### 1. Interface Renaming + +Several interfaces have been renamed to better reflect their roles: + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `ClientMcpTransport` | `McpClientTransport` | +| `ServerMcpTransport` | `McpServerTransport` | +| `DefaultMcpSession` | `McpClientSession`, `McpServerSession` | + +### 2. New Server Transport Architecture + +The most significant change is the introduction of the `McpServerTransportProvider` interface, which replaces direct usage of `ServerMcpTransport` when creating servers. This new pattern separates the concerns of: + +1. **Transport Provider**: Manages connections with clients and creates individual transports for each connection +2. **Server Transport**: Handles communication with a specific client connection + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `ServerMcpTransport` | `McpServerTransportProvider` + `McpServerTransport` | +| Direct transport usage | Session-based transport usage | + +#### Before (0.7.0): + +```java +// Create a transport +ServerMcpTransport transport = new WebFluxSseServerTransport(objectMapper, "/mcp/message"); + +// Create a server with the transport +McpServer.sync(transport) + .serverInfo("my-server", "1.0.0") + .build(); +``` + +#### After (0.8.0): + +```java +// Create a transport provider +McpServerTransportProvider transportProvider = new WebFluxSseServerTransportProvider(objectMapper, "/mcp/message"); + +// Create a server with the transport provider +McpServer.sync(transportProvider) + .serverInfo("my-server", "1.0.0") + .build(); +``` + +### 3. Handler Method Signature Changes + +Tool, resource, and prompt handlers now receive an additional `exchange` parameter that provides access to client capabilities and methods to interact with the client: + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `(args) -> result` | `(exchange, args) -> result` | + +The exchange objects (`McpAsyncServerExchange` and `McpSyncServerExchange`) provide context for the current session and access to session-specific operations. + +#### Before (0.7.0): + +```java +// Tool handler +.tool(calculatorTool, args -> new CallToolResult("Result: " + calculate(args))) + +// Resource handler +.resource(fileResource, req -> new ReadResourceResult(readFile(req))) + +// Prompt handler +.prompt(analysisPrompt, req -> new GetPromptResult("Analysis prompt")) +``` + +#### After (0.8.0): + +```java +// Tool handler +.tool(calculatorTool, (exchange, args) -> new CallToolResult("Result: " + calculate(args))) + +// Resource handler +.resource(fileResource, (exchange, req) -> new ReadResourceResult(readFile(req))) + +// Prompt handler +.prompt(analysisPrompt, (exchange, req) -> new GetPromptResult("Analysis prompt")) +``` + +### 4. Registration vs. Specification + +The naming convention for handlers has changed from "Registration" to "Specification": + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `AsyncToolRegistration` | `AsyncToolSpecification` | +| `SyncToolRegistration` | `SyncToolSpecification` | +| `AsyncResourceRegistration` | `AsyncResourceSpecification` | +| `SyncResourceRegistration` | `SyncResourceSpecification` | +| `AsyncPromptRegistration` | `AsyncPromptSpecification` | +| `SyncPromptRegistration` | `SyncPromptSpecification` | + +### 5. Roots Change Handler Updates + +The roots change handlers now receive an exchange parameter: + +#### Before (0.7.0): + +```java +.rootsChangeConsumers(List.of( + roots -> { + // Process roots + } +)) +``` + +#### After (0.8.0): + +```java +.rootsChangeHandlers(List.of( + (exchange, roots) -> { + // Process roots with access to exchange + } +)) +``` + +### 6. Server Creation Method Changes + +The `McpServer` factory methods now accept `McpServerTransportProvider` instead of `ServerMcpTransport`: + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `McpServer.async(ServerMcpTransport)` | `McpServer.async(McpServerTransportProvider)` | +| `McpServer.sync(ServerMcpTransport)` | `McpServer.sync(McpServerTransportProvider)` | + +The method names for creating servers have been updated: + +Root change handlers now receive an exchange object: + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `rootsChangeConsumers(List>>)` | `rootsChangeHandlers(List>>)` | +| `rootsChangeConsumer(Consumer>)` | `rootsChangeHandler(BiConsumer>)` | + +### 7. Direct Server Methods Moving to Exchange + +Several methods that were previously available directly on the server are now accessed through the exchange object: + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `server.listRoots()` | `exchange.listRoots()` | +| `server.createMessage()` | `exchange.createMessage()` | +| `server.getClientCapabilities()` | `exchange.getClientCapabilities()` | +| `server.getClientInfo()` | `exchange.getClientInfo()` | + +The direct methods are deprecated and will be removed in 0.9.0: + +- `McpSyncServer.listRoots()` +- `McpSyncServer.getClientCapabilities()` +- `McpSyncServer.getClientInfo()` +- `McpSyncServer.createMessage()` +- `McpAsyncServer.listRoots()` +- `McpAsyncServer.getClientCapabilities()` +- `McpAsyncServer.getClientInfo()` +- `McpAsyncServer.createMessage()` + +## Deprecation Notices + +The following components are deprecated in 0.8.0 and will be removed in 0.9.0: + +- `ClientMcpTransport` interface (use `McpClientTransport` instead) +- `ServerMcpTransport` interface (use `McpServerTransport` instead) +- `DefaultMcpSession` class (use `McpClientSession` instead) +- `WebFluxSseServerTransport` class (use `WebFluxSseServerTransportProvider` instead) +- `WebMvcSseServerTransport` class (use `WebMvcSseServerTransportProvider` instead) +- `StdioServerTransport` class (use `StdioServerTransportProvider` instead) +- All `*Registration` classes (use corresponding `*Specification` classes instead) +- Direct server methods for client interaction (use exchange object instead) + +## Migration Examples + +### Example 1: Creating a Server + +#### Before (0.7.0): + +```java +// Create a transport +ServerMcpTransport transport = new WebFluxSseServerTransport(objectMapper, "/mcp/message"); + +// Create a server with the transport +var server = McpServer.sync(transport) + .serverInfo("my-server", "1.0.0") + .tool(calculatorTool, args -> new CallToolResult("Result: " + calculate(args))) + .rootsChangeConsumers(List.of( + roots -> System.out.println("Roots changed: " + roots) + )) + .build(); + +// Get client capabilities directly from server +ClientCapabilities capabilities = server.getClientCapabilities(); +``` + +#### After (0.8.0): + +```java +// Create a transport provider +McpServerTransportProvider transportProvider = new WebFluxSseServerTransportProvider(objectMapper, "/mcp/message"); + +// Create a server with the transport provider +var server = McpServer.sync(transportProvider) + .serverInfo("my-server", "1.0.0") + .tool(calculatorTool, (exchange, args) -> { + // Get client capabilities from exchange + ClientCapabilities capabilities = exchange.getClientCapabilities(); + return new CallToolResult("Result: " + calculate(args)); + }) + .rootsChangeHandlers(List.of( + (exchange, roots) -> System.out.println("Roots changed: " + roots) + )) + .build(); +``` + +### Example 2: Implementing a Tool with Client Interaction + +#### Before (0.7.0): + +```java +McpServerFeatures.SyncToolRegistration tool = new McpServerFeatures.SyncToolRegistration( + new Tool("weather", "Get weather information", schema), + args -> { + String location = (String) args.get("location"); + // Cannot interact with client from here + return new CallToolResult("Weather for " + location + ": Sunny"); + } +); + +var server = McpServer.sync(transport) + .tools(tool) + .build(); + +// Separate call to create a message +CreateMessageResult result = server.createMessage(new CreateMessageRequest(...)); +``` + +#### After (0.8.0): + +```java +McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( + new Tool("weather", "Get weather information", schema), + (exchange, args) -> { + String location = (String) args.get("location"); + + // Can interact with client directly from the tool handler + CreateMessageResult result = exchange.createMessage(new CreateMessageRequest(...)); + + return new CallToolResult("Weather for " + location + ": " + result.content()); + } +); + +var server = McpServer.sync(transportProvider) + .tools(tool) + .build(); +``` + +### Example 3: Converting Existing Registration Classes + +If you have custom implementations of the registration classes, you can convert them to the new specification classes: + +#### Before (0.7.0): + +```java +McpServerFeatures.AsyncToolRegistration toolReg = new McpServerFeatures.AsyncToolRegistration( + tool, + args -> Mono.just(new CallToolResult("Result")) +); + +McpServerFeatures.AsyncResourceRegistration resourceReg = new McpServerFeatures.AsyncResourceRegistration( + resource, + req -> Mono.just(new ReadResourceResult(List.of())) +); +``` + +#### After (0.8.0): + +```java +// Option 1: Create new specification directly +McpServerFeatures.AsyncToolSpecification toolSpec = new McpServerFeatures.AsyncToolSpecification( + tool, + (exchange, args) -> Mono.just(new CallToolResult("Result")) +); + +// Option 2: Convert from existing registration (during transition) +McpServerFeatures.AsyncToolRegistration oldToolReg = /* existing registration */; +McpServerFeatures.AsyncToolSpecification toolSpec = oldToolReg.toSpecification(); + +// Similarly for resources +McpServerFeatures.AsyncResourceSpecification resourceSpec = new McpServerFeatures.AsyncResourceSpecification( + resource, + (exchange, req) -> Mono.just(new ReadResourceResult(List.of())) +); +``` + +## Architecture Changes + +### Session-Based Architecture + +In 0.8.0, the MCP Java SDK introduces a session-based architecture where each client connection has its own session. This allows for better isolation between clients and more efficient resource management. + +The `McpServerTransportProvider` is responsible for creating `McpServerTransport` instances for each session, and the `McpServerSession` manages the communication with a specific client. + +### Exchange Objects + +The new exchange objects (`McpAsyncServerExchange` and `McpSyncServerExchange`) provide access to client-specific information and methods. They are passed to handler functions as the first parameter, allowing handlers to interact with the specific client that made the request. + +## Conclusion + +The changes in version 0.8.0 represent a significant architectural improvement to the MCP Java SDK. While they require some code changes, the new design provides a more flexible and maintainable foundation for building MCP applications. + +For assistance with migration or to report issues, please open an issue on the GitHub repository. diff --git a/pom.xml b/pom.xml index d7502155e..1534eb72e 100644 --- a/pom.xml +++ b/pom.xml @@ -4,27 +4,32 @@ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd"> 4.0.0 - org.springframework.experimental + io.modelcontextprotocol.sdk mcp-parent - 0.6.0 + 0.8.0 pom - https://github.com/spring-projects-experimental/spring-ai-mcp + https://github.com/modelcontextprotocol/java-sdk - Spring AI MCP Parent - Java SDK and Spring Framework integration for the Model Context Protocol (MCP), enabling interaction with AI models and tools through a standardized interface + + https://github.com/modelcontextprotocol/java-sdk + git://github.com/modelcontextprotocol/java-sdk.git + git@github.com/modelcontextprotocol/java-sdk.git + + + Java SDK MCP Parent + Java SDK for the Model Context Protocol (MCP), enabling interaction with AI models and tools through a standardized interface - Spring Framework - https://spring.io + Anthropic + https://www.anthropic.com - Apache 2.0 - https://www.apache.org/licenses/LICENSE-2.0.txt - repo + MIT License + http://www.opensource.org/licenses/mit-license.php @@ -36,28 +41,15 @@ Dariusz Jędrzejczyk - - https://github.com/spring-projects-experimental/spring-ai-mcp - git://github.com/spring-projects-experimental/spring-ai-mcp.git - git@github.com:spring-projects-experimental/spring-ai-mcp.git - + Github Issues - https://github.com/spring-projects-experimental/spring-ai-mcp/issues + https://github.com/modelcontextprotocol/java-sdk/issues Github Actions - https://github.com/spring-projects-experimental/spring-ai-mcp/actions + https://github.com/modelcontextprotocol/java-sdk/actions - - - spring-snapshots - https://repo.spring.io/libs-snapshot-local - - false - - - UTF-8 @@ -66,7 +58,6 @@ 17 17 - 1.0.0-M5 3.26.3 5.10.2 5.11.0 @@ -100,15 +91,16 @@ 11.0.2 6.1.0 4.2.0 + 7.1.0 + 4.1.0 + mcp-bom mcp - mcp-transport/mcp-webflux-sse-transport - mcp-transport/mcp-webmvc-sse-transport - spring-ai-mcp - mcp-docs + mcp-spring/mcp-spring-webflux + mcp-spring/mcp-spring-webmvc mcp-test @@ -299,28 +291,42 @@ - artifactory-staging - - - spring-staging - https://repo.spring.io/libs-staging-local - - false - - - - - - artifactory-milestone - - - spring-milestones - https://repo.spring.io/libs-milestone-local - - false - - - + release + + + + org.sonatype.central + central-publishing-maven-plugin + 0.7.0 + true + + central + true + + + + org.apache.maven.plugins + maven-gpg-plugin + 3.2.7 + + + sign-artifacts + verify + + sign + + + + --pinentry-mode + loopback + + ${env.MAVEN_GPG_PASSPHRASE} + + + + + + @@ -333,13 +339,6 @@ pom import - - org.springframework.ai - spring-ai-bom - ${spring-ai.version} - pom - import - @@ -355,14 +354,6 @@ true - - spring-milestones - Spring Milestones - https://repo.spring.io/milestone - - false - - - + \ No newline at end of file diff --git a/settings.xml b/settings.xml deleted file mode 100644 index 890e93070..000000000 --- a/settings.xml +++ /dev/null @@ -1,24 +0,0 @@ - - - - - spring-snapshots - ${env.ARTIFACTORY_USERNAME} - ${env.ARTIFACTORY_PASSWORD} - - - spring-staging - ${env.ARTIFACTORY_USERNAME} - ${env.ARTIFACTORY_PASSWORD} - - - spring-milestones - ${env.ARTIFACTORY_USERNAME} - ${env.ARTIFACTORY_PASSWORD} - - - - \ No newline at end of file diff --git a/spring-ai-mcp/README.md b/spring-ai-mcp/README.md deleted file mode 100644 index 8ba646f83..000000000 --- a/spring-ai-mcp/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# Spring AI MCP Spring - -Spring Integration module for Model Control Protocol (MCP) that provides Spring-specific functionality for working with MCP clients. - -Find more at [Java MCP SDK](https://docs.spring.io/spring-ai-mcp/reference/spring-mcp.html) \ No newline at end of file diff --git a/spring-ai-mcp/pom.xml b/spring-ai-mcp/pom.xml deleted file mode 100644 index ab74c04b4..000000000 --- a/spring-ai-mcp/pom.xml +++ /dev/null @@ -1,53 +0,0 @@ - - - 4.0.0 - - org.springframework.experimental - mcp-parent - 0.6.0 - - spring-ai-mcp - Spring AI MCP - Spring Framework integration for Model Context Protocol (MCP), providing Spring AI function calling capabilities and Spring-friendly abstractions for MCP clients and MCP servers - https://github.com/spring-projects-experimental/spring-ai-mcp - - - https://github.com/spring-projects-experimental/spring-ai-mcp - git://github.com/spring-projects-experimental/spring-ai-mcp.git - git@github.com:spring-projects-experimental/spring-ai-mcp.git - - - - - org.springframework.ai - spring-ai-core - ${spring-ai.version} - - - org.springframework - spring-context - - - org.springframework - spring-messaging - - - - - - org.springframework.experimental - mcp - ${project.version} - - - - org.springframework - spring-core - ${springframework.version} - - - - - diff --git a/spring-ai-mcp/src/main/java/org/springframework/ai/mcp/spring/McpFunctionCallback.java b/spring-ai-mcp/src/main/java/org/springframework/ai/mcp/spring/McpFunctionCallback.java deleted file mode 100644 index e06b02625..000000000 --- a/spring-ai-mcp/src/main/java/org/springframework/ai/mcp/spring/McpFunctionCallback.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mcp.spring; - -import java.util.Map; - -import org.springframework.ai.mcp.client.McpSyncClient; -import org.springframework.ai.mcp.spec.McpSchema.CallToolRequest; -import org.springframework.ai.mcp.spec.McpSchema.CallToolResult; -import org.springframework.ai.mcp.spec.McpSchema.Tool; - -import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.model.function.FunctionCallback; - -/** - * @author Christian Tzolov - */ - -public class McpFunctionCallback implements FunctionCallback { - - // TODO: revisit function calling as well to handle the async case - private final McpSyncClient mcpClient; - - private final Tool tool; - - public McpFunctionCallback(McpSyncClient clientSession, Tool tool) { - this.mcpClient = clientSession; - this.tool = tool; - } - - @Override - public String getName() { - return this.tool.name(); - } - - @Override - public String getDescription() { - return this.tool.description(); - } - - @Override - public String getInputTypeSchema() { - return ModelOptionsUtils.toJsonString(this.tool.inputSchema()); - } - - @Override - public String call(String functionInput) { - Map arguments = ModelOptionsUtils.jsonToMap(functionInput); - CallToolResult response = this.mcpClient.callTool(new CallToolRequest(this.getName(), arguments)); - // Todo handle errors - return ModelOptionsUtils.toJsonString(response.content()); - } - -} diff --git a/spring-ai-mcp/src/main/java/org/springframework/ai/mcp/spring/ToolHelper.java b/spring-ai-mcp/src/main/java/org/springframework/ai/mcp/spring/ToolHelper.java deleted file mode 100644 index a3c5e44a5..000000000 --- a/spring-ai-mcp/src/main/java/org/springframework/ai/mcp/spring/ToolHelper.java +++ /dev/null @@ -1,218 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -package org.springframework.ai.mcp.spring; - -import java.util.List; -import java.util.Map; - -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.node.ObjectNode; -import com.fasterxml.jackson.module.jsonSchema.JsonSchema; -import com.fasterxml.jackson.module.jsonSchema.JsonSchemaGenerator; - -import org.springframework.ai.chat.model.ToolContext; -import org.springframework.ai.mcp.server.McpServer; -import org.springframework.ai.mcp.server.McpServerFeatures; -import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.util.ClassUtils; - -/** - * Utility class that provides helper methods for working with Model Context Protocol - * (MCP) tools in a Spring AI environment. This class facilitates the integration between - * Spring AI's function callbacks and MCP's tool system. - * - *

    - * The MCP tool system enables servers to expose executable functionality to language - * models, allowing them to interact with external systems, perform computations, and take - * actions in the real world. Each tool is uniquely identified by a name and includes - * metadata describing its schema. - * - *

    - * This helper class provides methods to: - *

      - *
    • Convert Spring AI's {@link FunctionCallback} instances to MCP tool - * registrations
    • - *
    • Generate JSON schemas for tool input validation
    • - *
    - * - * @author Christian Tzolov - * @see org.springframework.ai.model.function.FunctionCallback - * @see org.springframework.ai.mcp.server.McpServer.ToolRegistration - * @see org.springframework.ai.mcp.spec.McpSchema.Tool - */ -public final class ToolHelper { - - private ToolHelper() { - } - - public static List toSyncToolRegistration( - List functionCallbacks) { - return functionCallbacks.stream().map(ToolHelper::toSyncToolRegistration).toList(); - } - - @Deprecated - public static List toToolRegistration(List functionCallbacks) { - return functionCallbacks.stream().map(ToolHelper::toToolRegistration).toList(); - } - - public static List toSyncToolRegistration( - FunctionCallback... functionCallbacks) { - return toSyncToolRegistration(List.of(functionCallbacks)); - } - - @Deprecated - public static List toToolRegistration(FunctionCallback... functionCallbacks) { - return toToolRegistration(List.of(functionCallbacks)); - } - - /** - * Converts a Spring AI FunctionCallback to an MCP SyncToolRegistration. This enables - * Spring AI functions to be exposed as MCP tools that can be discovered and invoked - * by language models. - * - *

    - * The conversion process: - *

      - *
    • Creates an MCP Tool with the function's name and input schema
    • - *
    • Wraps the function's execution in a SyncToolRegistration that handles the MCP - * protocol
    • - *
    • Provides error handling and result formatting according to MCP - * specifications
    • - *
    - * - * You can use the FunctionCallback builder to create a new instance of - * FunctionCallback using either java.util.function.Function or Method reference. - * @param functionCallback the Spring AI function callback to convert - * @return an MCP SyncToolRegistration that wraps the function callback - * @throws RuntimeException if there's an error during the function execution - */ - public static McpServerFeatures.SyncToolRegistration toSyncToolRegistration(FunctionCallback functionCallback) { - var tool = new McpSchema.Tool(functionCallback.getName(), functionCallback.getName(), - functionCallback.getInputTypeSchema()); - - return new McpServerFeatures.SyncToolRegistration(tool, request -> { - try { - String callResult = functionCallback.call(ModelOptionsUtils.toJsonString(request)); - return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(callResult)), false); - } - catch (Exception e) { - return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(e.getMessage())), true); - } - }); - } - - /** - * Converts a Spring AI FunctionCallback to an MCP ToolRegistration. This enables - * Spring AI functions to be exposed as MCP tools that can be discovered and invoked - * by language models. - * - *

    - * The conversion process: - *

      - *
    • Creates an MCP Tool with the function's name and input schema
    • - *
    • Wraps the function's execution in a ToolRegistration that handles the MCP - * protocol
    • - *
    • Provides error handling and result formatting according to MCP - * specifications
    • - *
    - * - * You can use the FunctionCallback builder to create a new instance of - * FunctionCallback using either java.util.function.Function or Method reference. - * @param functionCallback the Spring AI function callback to convert - * @return an MCP ToolRegistration that wraps the function callback - * @throws RuntimeException if there's an error during the function execution - * @deprecated Use {@link #toSyncToolRegistration(FunctionCallback)}. - */ - @Deprecated - public static McpServer.ToolRegistration toToolRegistration(FunctionCallback functionCallback) { - - var tool = new McpSchema.Tool(functionCallback.getName(), functionCallback.getName(), - functionCallback.getInputTypeSchema()); - - return new McpServer.ToolRegistration(tool, request -> { - try { - String callResult = functionCallback.call(ModelOptionsUtils.toJsonString(request)); - return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(callResult)), false); - } - catch (Exception e) { - return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(e.getMessage())), true); - } - }); - } - - /** - * Generates a JSON schema for a map of named classes using the default ObjectMapper. - * This schema can be used to validate tool inputs according to the MCP specification. - * @param namedClasses a map of class names to their corresponding Class objects - * @return a JSON schema string that describes the structure of the named classes - * @throws RuntimeException if schema generation fails - * @see #generateJsonSchema(Map, ObjectMapper) - */ - public static String generateJsonSchema(Map> namedClasses) { - return generateJsonSchema(namedClasses, new ObjectMapper()); - } - - /** - * Generates a JSON schema for a map of named classes using a custom ObjectMapper. The - * generated schema follows the JSON Schema Draft 2020-12 specification and describes - * the structure of the provided classes. - * - *

    - * This method: - *

      - *
    • Creates a schema that validates the structure of tool inputs
    • - *
    • Excludes ToolContext class from schema generation
    • - *
    • Uses Jackson's JsonSchemaGenerator for accurate type representation
    • - *
    - * @param namedClasses a map of class names to their corresponding Class objects - * @param mapper the ObjectMapper to use for JSON processing - * @return a JSON schema string that describes the structure of the named classes - * @throws RuntimeException if schema generation fails - */ - public static String generateJsonSchema(Map> namedClasses, ObjectMapper mapper) { - try { - JsonSchemaGenerator schemaGen = new JsonSchemaGenerator(mapper); - - ObjectNode rootNode = mapper.createObjectNode(); - rootNode.put("$schema", "https://json-schema.org/draft/2020-12/schema"); - rootNode.put("type", "object"); - ObjectNode propertiesNode = rootNode.putObject("properties"); - - for (Map.Entry> entry : namedClasses.entrySet()) { - String className = entry.getKey(); - Class clazz = entry.getValue(); - - if (ClassUtils.isAssignable(clazz, ToolContext.class)) { - // Skip the ToolContext class from the schema generation. - continue; - } - - JsonSchema schema = schemaGen.generateSchema(clazz); - JsonNode schemaNode = mapper.valueToTree(schema); - propertiesNode.set(className, schemaNode); - } - - return mapper.writerWithDefaultPrettyPrinter().writeValueAsString(rootNode); - } - catch (Exception e) { - throw new RuntimeException(e); - } - } - -}