diff --git a/src/a2a/types.py b/src/a2a/types.py index 918a06b5e..cc1cbcb16 100644 --- a/src/a2a/types.py +++ b/src/a2a/types.py @@ -34,7 +34,7 @@ class APIKeySecurityScheme(A2ABaseModel): """ An optional description for the security scheme. """ - in_: In + in_: In = Field(..., alias='in') """ The location of the API key. """ diff --git a/src/a2a/utils/telemetry.py b/src/a2a/utils/telemetry.py index c73d2ac92..22d1e8730 100644 --- a/src/a2a/utils/telemetry.py +++ b/src/a2a/utils/telemetry.py @@ -61,6 +61,8 @@ def internal_method(self): from collections.abc import Callable from typing import TYPE_CHECKING, Any +from typing_extensions import Self + if TYPE_CHECKING: from opentelemetry.trace import SpanKind as SpanKindType @@ -86,7 +88,7 @@ class _NoOp: def __call__(self, *args: Any, **kwargs: Any) -> Any: return self - def __enter__(self) -> '_NoOp': + def __enter__(self) -> Self: return self def __exit__(self, *args: object, **kwargs: Any) -> None: diff --git a/tests/server/apps/rest/test_rest_serialization.py b/tests/server/apps/rest/test_rest_serialization.py new file mode 100644 index 000000000..b76761643 --- /dev/null +++ b/tests/server/apps/rest/test_rest_serialization.py @@ -0,0 +1,66 @@ +from unittest import mock + +import pytest + +from httpx import ASGITransport, AsyncClient + +from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication +from a2a.types import ( + APIKeySecurityScheme, + AgentCapabilities, + AgentCard, + In, + SecurityScheme, +) + + +@pytest.fixture +def agent_card_with_api_key() -> AgentCard: + api_key_scheme_data = { + 'type': 'apiKey', + 'name': 'X-API-KEY', + 'in': 'header', + } + api_key_scheme = APIKeySecurityScheme.model_validate(api_key_scheme_data) + + return AgentCard( + name='APIKeyAgent', + description='An agent that uses API Key auth.', + url='http://example.com/apikey-agent', + version='1.0.0', + capabilities=AgentCapabilities(), + default_input_modes=['text/plain'], + default_output_modes=['text/plain'], + skills=[], + security_schemes={'api_key_auth': SecurityScheme(root=api_key_scheme)}, + security=[{'api_key_auth': []}], + ) + + +@pytest.mark.anyio +async def test_rest_agent_card_with_api_key_scheme_alias( + agent_card_with_api_key: AgentCard, +): + """Ensures REST agent card serialization uses the 'in' alias.""" + handler = mock.AsyncMock() + app_instance = A2ARESTFastAPIApplication(agent_card_with_api_key, handler) + app = app_instance.build( + agent_card_url='/.well-known/agent.json', rpc_url='' + ) + + async with AsyncClient( + transport=ASGITransport(app=app), base_url='http://test' + ) as client: + response = await client.get('/.well-known/agent.json') + + assert response.status_code == 200 + response_data = response.json() + + security_scheme_json = response_data['securitySchemes']['api_key_auth'] + assert 'in' in security_scheme_json + assert security_scheme_json['in'] == 'header' + assert 'in_' not in security_scheme_json + + parsed_card = AgentCard.model_validate(response_data) + parsed_scheme_wrapper = parsed_card.security_schemes['api_key_auth'] + assert parsed_scheme_wrapper.root.in_ == In.header diff --git a/tests/test_types.py b/tests/test_types.py index 73e6af7bb..563bf5e2f 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -182,13 +182,23 @@ # --- Test Functions --- -def test_security_scheme_valid(): - scheme = SecurityScheme.model_validate(MINIMAL_AGENT_SECURITY_SCHEME) +@pytest.mark.parametrize('in_field_name', ['in', 'in_']) +def test_security_scheme_in_field_handling(in_field_name: str) -> None: + scheme_data = { + 'type': 'apiKey', + 'name': 'X-API-KEY', + in_field_name: 'header', + } + scheme = SecurityScheme.model_validate(scheme_data) assert isinstance(scheme.root, APIKeySecurityScheme) assert scheme.root.type == 'apiKey' assert scheme.root.in_ == In.header assert scheme.root.name == 'X-API-KEY' + serialized_data = scheme.model_dump(mode='json', exclude_none=True) + assert serialized_data.get('in') == 'header' + assert 'in_' not in serialized_data + def test_security_scheme_invalid(): with pytest.raises(ValidationError):