GLMSearch/GLMSearch/McpEndpointRouteBuilderExtensions.cs
2025-05-06 17:56:52 +08:00

157 lines
7.7 KiB
C#

using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Routing.Patterns;
using Microsoft.AspNetCore.WebUtilities;
using Microsoft.Extensions.Options;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Server;
using ModelContextProtocol.Utils.Json;
using System.Collections.Concurrent;
using System.Diagnostics.CodeAnalysis;
using System.Security.Cryptography;
namespace WebSSE
{
/// <summary>
/// Extension methods for <see cref="IEndpointRouteBuilder"/> to add MCP endpoints.
/// https://github.com/modelcontextprotocol/csharp-sdk/tree/main
/// </summary>
public static class McpEndpointRouteBuilderExtensions
{
/// <summary>
/// Sets up endpoints for handling MCP HTTP Streaming transport.
/// </summary>
/// <param name="endpoints">The web application to attach MCP HTTP endpoints.</param>
/// <param name="pattern">The route pattern prefix to map to.</param>
/// <param name="configureOptionsAsync">Configure per-session options.</param>
/// <param name="runSessionAsync">Provides an optional asynchronous callback for handling new MCP sessions.</param>
/// <returns>Returns a builder for configuring additional endpoint conventions like authorization policies.</returns>
public static IEndpointConventionBuilder MapMcp(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern = "",
Func<HttpContext, McpServerOptions, CancellationToken, Task>? configureOptionsAsync = null,
Func<HttpContext, IMcpServer, CancellationToken, Task>? runSessionAsync = null)
=> endpoints.MapMcp(RoutePatternFactory.Parse(pattern), configureOptionsAsync, runSessionAsync);
/// <summary>
/// Sets up endpoints for handling MCP HTTP Streaming transport.
/// </summary>
/// <param name="endpoints">The web application to attach MCP HTTP endpoints.</param>
/// <param name="pattern">The route pattern prefix to map to.</param>
/// <param name="configureOptionsAsync">Configure per-session options.</param>
/// <param name="runSessionAsync">Provides an optional asynchronous callback for handling new MCP sessions.</param>
/// <returns>Returns a builder for configuring additional endpoint conventions like authorization policies.</returns>
public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints,
RoutePattern pattern,
Func<HttpContext, McpServerOptions, CancellationToken, Task>? configureOptionsAsync = null,
Func<HttpContext, IMcpServer, CancellationToken, Task>? runSessionAsync = null)
{
ConcurrentDictionary<string, SseResponseStreamTransport> _sessions = new(StringComparer.Ordinal);
var loggerFactory = endpoints.ServiceProvider.GetRequiredService<ILoggerFactory>();
var optionsSnapshot = endpoints.ServiceProvider.GetRequiredService<IOptions<McpServerOptions>>();
var optionsFactory = endpoints.ServiceProvider.GetRequiredService<IOptionsFactory<McpServerOptions>>();
var hostApplicationLifetime = endpoints.ServiceProvider.GetRequiredService<IHostApplicationLifetime>();
var routeGroup = endpoints.MapGroup(pattern);
routeGroup.MapGet("/sse", async context =>
{
// If the server is shutting down, we need to cancel all SSE connections immediately without waiting for HostOptions.ShutdownTimeout
// which defaults to 30 seconds.
using var sseCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted, hostApplicationLifetime.ApplicationStopping);
var cancellationToken = sseCts.Token;
var response = context.Response;
response.Headers.ContentType = "text/event-stream";
response.Headers.CacheControl = "no-cache,no-store";
// Make sure we disable all response buffering for SSE
context.Response.Headers.ContentEncoding = "identity";
context.Features.GetRequiredFeature<IHttpResponseBodyFeature>().DisableBuffering();
var sessionId = MakeNewSessionId();
await using var transport = new SseResponseStreamTransport(response.Body, $"/message?sessionId={sessionId}");
if (!_sessions.TryAdd(sessionId, transport))
{
throw new Exception($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created.");
}
var options = optionsSnapshot.Value;
if (configureOptionsAsync is not null)
{
options = optionsFactory.Create(Options.DefaultName);
await configureOptionsAsync.Invoke(context, options, cancellationToken);
}
try
{
var transportTask = transport.RunAsync(cancellationToken);
try
{
await using var mcpServer = McpServerFactory.Create(transport, options, loggerFactory, endpoints.ServiceProvider);
context.Features.Set(mcpServer);
runSessionAsync ??= RunSession;
await runSessionAsync(context, mcpServer, cancellationToken);
}
finally
{
await transport.DisposeAsync();
await transportTask;
}
}
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
{
// RequestAborted always triggers when the client disconnects before a complete response body is written,
// but this is how SSE connections are typically closed.
}
finally
{
_sessions.TryRemove(sessionId, out _);
}
});
routeGroup.MapPost("/message", async context =>
{
if (!context.Request.Query.TryGetValue("sessionId", out var sessionId))
{
await Results.BadRequest("Missing sessionId query parameter.").ExecuteAsync(context);
return;
}
if (!_sessions.TryGetValue(sessionId.ToString(), out var transport))
{
await Results.BadRequest($"Session ID not found.").ExecuteAsync(context);
return;
}
var message = (IJsonRpcMessage?)await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)), context.RequestAborted);
if (message is null)
{
await Results.BadRequest("No message in request body.").ExecuteAsync(context);
return;
}
await transport.OnMessageReceivedAsync(message, context.RequestAborted);
context.Response.StatusCode = StatusCodes.Status202Accepted;
await context.Response.WriteAsync("Accepted");
});
return routeGroup;
}
private static Task RunSession(HttpContext httpContext, IMcpServer session, CancellationToken requestAborted)
=> session.RunAsync(requestAborted);
private static string MakeNewSessionId()
{
// 128 bits
Span<byte> buffer = stackalloc byte[16];
RandomNumberGenerator.Fill(buffer);
return WebEncoders.Base64UrlEncode(buffer);
}
}
}