Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 43 additions & 17 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,23 +478,11 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
await response(scope, receive, send)
return

# Check if this is an initialization request
is_initialization_request = isinstance(message, JSONRPCRequest) and message.method == "initialize"

if is_initialization_request:
# Check if the server already has an established session
if self.mcp_session_id:
# Check if request has a session ID
request_session_id = self._get_session_id(request)

# If request has a session ID but doesn't match, return 404
if request_session_id and request_session_id != self.mcp_session_id: # pragma: no cover
response = self._create_error_response(
"Not Found: Invalid or expired session ID",
HTTPStatus.NOT_FOUND,
)
await response(scope, receive, send)
return
is_initialization_request = False
if isinstance(message, JSONRPCRequest) and message.method == "initialize":
is_initialization_request = True
if not await self._validate_initialization_request(message, request, send):
return
elif not await self._validate_request_headers(request, send): # pragma: no cover
return

Expand Down Expand Up @@ -867,6 +855,44 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool

return True

async def _validate_initialization_request(self, message: JSONRPCRequest, request: Request, send: Send) -> bool:
if not await self._validate_initialization_protocol_version(message, request, send):
return False

if not self.mcp_session_id:
return True

request_session_id = self._get_session_id(request)
if request_session_id and request_session_id != self.mcp_session_id: # pragma: no cover
response = self._create_error_response(
"Not Found: Invalid or expired session ID",
HTTPStatus.NOT_FOUND,
)
await response(request.scope, request.receive, send)
return False

return True

async def _validate_initialization_protocol_version(
self, message: JSONRPCRequest, request: Request, send: Send
) -> bool:
header_protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER)
body_protocol_version = str(message.params.get("protocolVersion")) if message.params else None
if (
header_protocol_version is not None
and body_protocol_version is not None
and header_protocol_version != body_protocol_version
):
response = self._create_error_response(
f"Bad Request: {MCP_PROTOCOL_VERSION_HEADER} header does not match initialize.params.protocolVersion",
HTTPStatus.BAD_REQUEST,
INVALID_REQUEST,
)
await response(request.scope, request.receive, send)
return False

return True

async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: # pragma: no cover
"""Replays events that would have been sent after the specified event ID.

Expand Down
99 changes: 99 additions & 0 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.routing import Mount
from starlette.types import Message

from mcp import MCPError, types
from mcp.client.session import ClientSession
Expand Down Expand Up @@ -1718,6 +1719,104 @@ def test_server_validates_protocol_version_header(basic_server: None, basic_serv
assert response.status_code == 200


@pytest.mark.parametrize(
("header_version", "body_version"),
[
("2025-03-26", "2025-06-18"),
("2025-06-18", "2025-03-26"),
],
)
def test_server_rejects_initialize_protocol_version_mismatch(
basic_server: None, basic_server_url: str, header_version: str, body_version: str
):
"""Test initialize rejects conflicting protocol versions in header and body."""
init_request: dict[str, Any] = {
"jsonrpc": "2.0",
"method": "initialize",
"params": {
"clientInfo": {"name": "test-client", "version": "1.0"},
"protocolVersion": body_version,
"capabilities": {},
},
"id": "init-1",
}

response = requests.post(
f"{basic_server_url}/mcp",
headers={
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
MCP_PROTOCOL_VERSION_HEADER: header_version,
},
json=init_request,
)

assert response.status_code == 400
assert MCP_PROTOCOL_VERSION_HEADER in response.text
assert "protocolVersion" in response.text


@pytest.mark.anyio
async def test_server_rejects_initialize_protocol_version_mismatch_in_process():
transport = StreamableHTTPServerTransport("/mcp")
write_stream, read_stream = create_context_streams[SessionMessage | Exception](1)
transport._read_stream_writer = write_stream
body = json.dumps(
{
"jsonrpc": "2.0",
"method": "initialize",
"params": {
"clientInfo": {"name": "test-client", "version": "1.0"},
"protocolVersion": "2025-06-18",
"capabilities": {},
},
"id": "init-1",
}
).encode()
sent: list[Message] = []
received_body = False

async def receive() -> Message:
nonlocal received_body
if received_body:
return {"type": "http.disconnect"}

received_body = True
return {"type": "http.request", "body": body, "more_body": False}

async def send(message: Message) -> None:
sent.append(message)

scope = {
"type": "http",
"asgi": {"version": "3.0"},
"method": "POST",
"path": "/mcp",
"raw_path": b"/mcp",
"query_string": b"",
"headers": [
(b"accept", b"application/json, text/event-stream"),
(b"content-type", b"application/json"),
(MCP_PROTOCOL_VERSION_HEADER.encode(), b"2025-03-26"),
],
"client": ("testclient", 50000),
"server": ("testserver", 80),
"scheme": "http",
}

try:
await transport.handle_request(scope, receive, send)
assert await receive() == {"type": "http.disconnect"}
finally:
await write_stream.aclose()
await read_stream.aclose()

assert any(message["type"] == "http.response.start" and message["status"] == 400 for message in sent)
response_body = b"".join(message.get("body", b"") for message in sent if message["type"] == "http.response.body")
assert MCP_PROTOCOL_VERSION_HEADER.encode() in response_body
assert b"protocolVersion" in response_body


def test_server_backwards_compatibility_no_protocol_version(basic_server: None, basic_server_url: str):
"""Test server accepts requests without protocol version header."""
# First initialize a session to get a valid session ID
Expand Down
Loading