Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,46 +4,24 @@

package io.modelcontextprotocol.server;

import java.time.Duration;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.BiFunction;

import io.modelcontextprotocol.json.McpJsonMapper;
import io.modelcontextprotocol.json.TypeRef;
import io.modelcontextprotocol.json.schema.JsonSchemaValidator;
import io.modelcontextprotocol.spec.DefaultMcpStreamableServerSessionFactory;
import io.modelcontextprotocol.spec.McpClientSession;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
import io.modelcontextprotocol.spec.*;
import io.modelcontextprotocol.spec.McpSchema.*;
import io.modelcontextprotocol.spec.McpSchema.CompleteResult.CompleteCompletion;
import io.modelcontextprotocol.spec.McpSchema.ErrorCodes;
import io.modelcontextprotocol.spec.McpSchema.LoggingLevel;
import io.modelcontextprotocol.spec.McpSchema.PromptReference;
import io.modelcontextprotocol.spec.McpSchema.ResourceReference;
import io.modelcontextprotocol.spec.McpSchema.SetLevelRequest;
import io.modelcontextprotocol.spec.McpSchema.Tool;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.spec.McpServerTransportProviderBase;
import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.DefaultMcpUriTemplateManagerFactory;
import io.modelcontextprotocol.util.McpUriTemplateManagerFactory;
import io.modelcontextprotocol.util.Utils;
import io.modelcontextprotocol.util.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.time.Duration;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.BiFunction;

import static io.modelcontextprotocol.spec.McpError.RESOURCE_NOT_FOUND;

/**
Expand Down Expand Up @@ -98,6 +76,8 @@ public class McpAsyncServer {

private final JsonSchemaValidator jsonSchemaValidator;

private final boolean validateToolInputs;

private final McpSchema.ServerCapabilities serverCapabilities;

private final McpSchema.Implementation serverInfo;
Expand Down Expand Up @@ -129,7 +109,8 @@ public class McpAsyncServer {
*/
McpAsyncServer(McpServerTransportProvider mcpTransportProvider, McpJsonMapper jsonMapper,
McpServerFeatures.Async features, Duration requestTimeout,
McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) {
McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator,
boolean validateToolInputs) {
this.mcpTransportProvider = mcpTransportProvider;
this.jsonMapper = jsonMapper;
this.serverInfo = features.serverInfo();
Expand All @@ -142,6 +123,7 @@ public class McpAsyncServer {
this.completions.putAll(features.completions());
this.uriTemplateManagerFactory = uriTemplateManagerFactory;
this.jsonSchemaValidator = jsonSchemaValidator;
this.validateToolInputs = validateToolInputs;

Map<String, McpRequestHandler<?>> requestHandlers = prepareRequestHandlers();
Map<String, McpNotificationHandler> notificationHandlers = prepareNotificationHandlers(features);
Expand All @@ -157,7 +139,8 @@ public class McpAsyncServer {

McpAsyncServer(McpStreamableServerTransportProvider mcpTransportProvider, McpJsonMapper jsonMapper,
McpServerFeatures.Async features, Duration requestTimeout,
McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) {
McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator,
boolean validateToolInputs) {
this.mcpTransportProvider = mcpTransportProvider;
this.jsonMapper = jsonMapper;
this.serverInfo = features.serverInfo();
Expand All @@ -170,6 +153,7 @@ public class McpAsyncServer {
this.completions.putAll(features.completions());
this.uriTemplateManagerFactory = uriTemplateManagerFactory;
this.jsonSchemaValidator = jsonSchemaValidator;
this.validateToolInputs = validateToolInputs;

Map<String, McpRequestHandler<?>> requestHandlers = prepareRequestHandlers();
Map<String, McpNotificationHandler> notificationHandlers = prepareNotificationHandlers(features);
Expand Down Expand Up @@ -543,6 +527,13 @@ private McpRequestHandler<CallToolResult> toolsCallRequestHandler() {
.build());
}

McpSchema.Tool tool = toolSpecification.get().tool();
CallToolResult validationError = ToolInputValidator.validate(tool, callToolRequest.arguments(),
this.validateToolInputs, this.jsonMapper, this.jsonSchemaValidator);
if (validationError != null) {
return Mono.just(validationError);
}

return toolSpecification.get().callHandler().apply(exchange, callToolRequest);
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.DefaultMcpUriTemplateManagerFactory;
import io.modelcontextprotocol.util.McpUriTemplateManagerFactory;
import io.modelcontextprotocol.util.ToolInputValidator;
import io.modelcontextprotocol.util.ToolNameValidator;
import reactor.core.publisher.Mono;

Expand Down Expand Up @@ -243,7 +244,7 @@ public McpAsyncServer build() {
: McpJsonDefaults.getSchemaValidator();

return new McpAsyncServer(transportProvider, jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper,
features, requestTimeout, uriTemplateManagerFactory, jsonSchemaValidator);
features, requestTimeout, uriTemplateManagerFactory, jsonSchemaValidator, validateToolInputs);
}

}
Expand All @@ -269,7 +270,7 @@ public McpAsyncServer build() {
var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator
: McpJsonDefaults.getSchemaValidator();
return new McpAsyncServer(transportProvider, jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper,
features, requestTimeout, uriTemplateManagerFactory, jsonSchemaValidator);
features, requestTimeout, uriTemplateManagerFactory, jsonSchemaValidator, validateToolInputs);
}

}
Expand All @@ -293,6 +294,8 @@ abstract class AsyncSpecification<S extends AsyncSpecification<S>> {

boolean strictToolNameValidation = ToolNameValidator.isStrictByDefault();

boolean validateToolInputs = ToolInputValidator.isEnabledByDefault();

/**
* The Model Context Protocol (MCP) allows servers to expose tools that can be
* invoked by language models. Tools enable models to interact with external
Expand Down Expand Up @@ -421,6 +424,18 @@ public AsyncSpecification<S> strictToolNameValidation(boolean strict) {
return this;
}

/**
* Sets whether to validate tool inputs against the tool's input schema. When set,
* this takes priority over the system property
* {@code io.modelcontextprotocol.validateToolInputs}.
* @param validate true to validate inputs and return error on validation failure
* @return This builder instance for method chaining
*/
public AsyncSpecification<S> validateToolInputs(boolean validate) {
this.validateToolInputs = validate;
return this;
}

/**
* Sets the server capabilities that will be advertised to clients during
* connection initialization. Capabilities define what features the server
Expand Down Expand Up @@ -818,7 +833,8 @@ public McpSyncServer build() {
var asyncServer = new McpAsyncServer(transportProvider,
jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, asyncFeatures, requestTimeout,
uriTemplateManagerFactory,
jsonSchemaValidator != null ? jsonSchemaValidator : McpJsonDefaults.getSchemaValidator());
jsonSchemaValidator != null ? jsonSchemaValidator : McpJsonDefaults.getSchemaValidator(),
validateToolInputs);
return new McpSyncServer(asyncServer, this.immediateExecution);
}

Expand Down Expand Up @@ -849,7 +865,7 @@ public McpSyncServer build() {
: McpJsonDefaults.getSchemaValidator();
var asyncServer = new McpAsyncServer(transportProvider,
jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, asyncFeatures, this.requestTimeout,
this.uriTemplateManagerFactory, jsonSchemaValidator);
this.uriTemplateManagerFactory, jsonSchemaValidator, validateToolInputs);
return new McpSyncServer(asyncServer, this.immediateExecution);
}

Expand All @@ -872,6 +888,8 @@ abstract class SyncSpecification<S extends SyncSpecification<S>> {

boolean strictToolNameValidation = ToolNameValidator.isStrictByDefault();

boolean validateToolInputs = ToolInputValidator.isEnabledByDefault();

/**
* The Model Context Protocol (MCP) allows servers to expose tools that can be
* invoked by language models. Tools enable models to interact with external
Expand Down Expand Up @@ -1004,6 +1022,18 @@ public SyncSpecification<S> strictToolNameValidation(boolean strict) {
return this;
}

/**
* Sets whether to validate tool inputs against the tool's input schema. When set,
* this takes priority over the system property
* {@code io.modelcontextprotocol.validateToolInputs}.
* @param validate true to validate inputs and return error on validation failure
* @return This builder instance for method chaining
*/
public SyncSpecification<S> validateToolInputs(boolean validate) {
this.validateToolInputs = validate;
return this;
}

/**
* Sets the server capabilities that will be advertised to clients during
* connection initialization. Capabilities define what features the server
Expand Down Expand Up @@ -1401,6 +1431,8 @@ class StatelessAsyncSpecification {

boolean strictToolNameValidation = ToolNameValidator.isStrictByDefault();

boolean validateToolInputs = ToolInputValidator.isEnabledByDefault();

/**
* The Model Context Protocol (MCP) allows servers to expose tools that can be
* invoked by language models. Tools enable models to interact with external
Expand Down Expand Up @@ -1530,6 +1562,18 @@ public StatelessAsyncSpecification strictToolNameValidation(boolean strict) {
return this;
}

/**
* Sets whether to validate tool inputs against the tool's input schema. When set,
* this takes priority over the system property
* {@code io.modelcontextprotocol.validateToolInputs}.
* @param validate true to validate inputs and return error on validation failure
* @return This builder instance for method chaining
*/
public StatelessAsyncSpecification validateToolInputs(boolean validate) {
this.validateToolInputs = validate;
return this;
}

/**
* Sets the server capabilities that will be advertised to clients during
* connection initialization. Capabilities define what features the server
Expand Down Expand Up @@ -1859,7 +1903,8 @@ public McpStatelessAsyncServer build() {
this.resources, this.resourceTemplates, this.prompts, this.completions, this.instructions);
return new McpStatelessAsyncServer(transport, jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper,
features, requestTimeout, uriTemplateManagerFactory,
jsonSchemaValidator != null ? jsonSchemaValidator : McpJsonDefaults.getSchemaValidator());
jsonSchemaValidator != null ? jsonSchemaValidator : McpJsonDefaults.getSchemaValidator(),
validateToolInputs);
}

}
Expand All @@ -1884,6 +1929,8 @@ class StatelessSyncSpecification {

boolean strictToolNameValidation = ToolNameValidator.isStrictByDefault();

boolean validateToolInputs = ToolInputValidator.isEnabledByDefault();

/**
* The Model Context Protocol (MCP) allows servers to expose tools that can be
* invoked by language models. Tools enable models to interact with external
Expand Down Expand Up @@ -2013,6 +2060,18 @@ public StatelessSyncSpecification strictToolNameValidation(boolean strict) {
return this;
}

/**
* Sets whether to validate tool inputs against the tool's input schema. When set,
* this takes priority over the system property
* {@code io.modelcontextprotocol.validateToolInputs}.
* @param validate true to validate inputs and return error on validation failure
* @return This builder instance for method chaining
*/
public StatelessSyncSpecification validateToolInputs(boolean validate) {
this.validateToolInputs = validate;
return this;
}

/**
* Sets the server capabilities that will be advertised to clients during
* connection initialization. Capabilities define what features the server
Expand Down Expand Up @@ -2360,7 +2419,8 @@ public McpStatelessSyncServer build() {
var asyncServer = new McpStatelessAsyncServer(transport,
jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, asyncFeatures, requestTimeout,
uriTemplateManagerFactory,
this.jsonSchemaValidator != null ? this.jsonSchemaValidator : McpJsonDefaults.getSchemaValidator());
this.jsonSchemaValidator != null ? this.jsonSchemaValidator : McpJsonDefaults.getSchemaValidator(),
validateToolInputs);
return new McpStatelessSyncServer(asyncServer, this.immediateExecution);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.DefaultMcpUriTemplateManagerFactory;
import io.modelcontextprotocol.util.McpUriTemplateManagerFactory;
import io.modelcontextprotocol.util.ToolInputValidator;
import io.modelcontextprotocol.util.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -77,9 +78,12 @@ public class McpStatelessAsyncServer {

private final JsonSchemaValidator jsonSchemaValidator;

private final boolean validateToolInputs;

McpStatelessAsyncServer(McpStatelessServerTransport mcpTransport, McpJsonMapper jsonMapper,
McpStatelessServerFeatures.Async features, Duration requestTimeout,
McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) {
McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator,
boolean validateToolInputs) {
this.mcpTransportProvider = mcpTransport;
this.jsonMapper = jsonMapper;
this.serverInfo = features.serverInfo();
Expand All @@ -92,6 +96,7 @@ public class McpStatelessAsyncServer {
this.completions.putAll(features.completions());
this.uriTemplateManagerFactory = uriTemplateManagerFactory;
this.jsonSchemaValidator = jsonSchemaValidator;
this.validateToolInputs = validateToolInputs;

Map<String, McpStatelessRequestHandler<?>> requestHandlers = new HashMap<>();

Expand Down Expand Up @@ -409,6 +414,13 @@ private McpStatelessRequestHandler<CallToolResult> toolsCallRequestHandler() {
.build());
}

McpSchema.Tool tool = toolSpecification.get().tool();
CallToolResult validationError = ToolInputValidator.validate(tool, callToolRequest.arguments(),
this.validateToolInputs, this.jsonMapper, this.jsonSchemaValidator);
if (validationError != null) {
return Mono.just(validationError);
}

return toolSpecification.get().callHandler().apply(ctx, callToolRequest);
};
}
Expand Down
Loading
Loading