Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import static io.a2a.transport.jsonrpc.context.JSONRPCContextKeys.HEADERS_KEY;
import static io.a2a.transport.jsonrpc.context.JSONRPCContextKeys.METHOD_NAME_KEY;
import static io.a2a.transport.jsonrpc.context.JSONRPCContextKeys.TENANT_KEY;
import static io.vertx.core.http.HttpHeaders.CONTENT_TYPE;
import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON;
import static jakarta.ws.rs.core.MediaType.SERVER_SENT_EVENTS;
Expand Down Expand Up @@ -101,7 +102,7 @@ public void invokeJSONRPCHandler(@Body String body, RoutingContext rc) {
Multi<? extends A2AResponse<?>> streamingResponse = null;
A2AErrorResponse error = null;
try {
A2ARequest<?> request = JSONRPCUtils.parseRequestBody(body);
A2ARequest<?> request = JSONRPCUtils.parseRequestBody(body, extractTenant(rc));
context.getState().put(METHOD_NAME_KEY, request.getMethod());
if (request instanceof NonStreamingJSONRPCRequest nonStreamingRequest) {
nonStreamingResponse = processNonStreamingRequest(nonStreamingRequest, context);
Expand Down Expand Up @@ -213,7 +214,6 @@ static void setStreamingMultiSseSupportSubscribedRunnable(Runnable runnable) {
}

private ServerCallContext createCallContext(RoutingContext rc) {

if (callContextFactory.isUnsatisfied()) {
User user;
if (rc.user() == null) {
Expand All @@ -240,6 +240,7 @@ public String getUsername() {
Set<String> headerNames = rc.request().headers().names();
headerNames.forEach(name -> headers.put(name, rc.request().getHeader(name)));
state.put(HEADERS_KEY, headers);
state.put(TENANT_KEY, extractTenant(rc));

// Extract requested protocol version from X-A2A-Version header
String requestedVersion = rc.request().getHeader(A2AHeaders.X_A2A_VERSION);
Expand All @@ -255,6 +256,20 @@ public String getUsername() {
}
}

private String extractTenant(RoutingContext rc) {
String tenantPath = rc.normalizedPath();
if (tenantPath == null || tenantPath.isBlank()) {
return "";
}
if (tenantPath.startsWith("/")) {
tenantPath = tenantPath.substring(1);
}
if(tenantPath.endsWith("/")) {
tenantPath = tenantPath.substring(0, tenantPath.length() -1);
}
return tenantPath;
}

private static String serializeResponse(A2AResponse<?> response) {
// For error responses, use Jackson serialization (errors are standardized)
if (response instanceof A2AErrorResponse error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import static io.a2a.spec.A2AMethods.SEND_STREAMING_MESSAGE_METHOD;
import static io.a2a.spec.AgentCard.CURRENT_PROTOCOL_VERSION;
import static io.a2a.transport.jsonrpc.context.JSONRPCContextKeys.METHOD_NAME_KEY;
import static io.a2a.transport.jsonrpc.context.JSONRPCContextKeys.TENANT_KEY;
import static java.util.Collections.singletonList;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
Expand Down Expand Up @@ -108,6 +109,7 @@ public void setUp() {
when(mockRoutingContext.user()).thenReturn(null);
when(mockRequest.headers()).thenReturn(mockHeaders);
when(mockRoutingContext.body()).thenReturn(mockRequestBody);
when(mockRoutingContext.normalizedPath()).thenReturn("/");

// Chain the response methods properly
when(mockHttpResponse.setStatusCode(any(Integer.class))).thenReturn(mockHttpResponse);
Expand Down Expand Up @@ -520,6 +522,202 @@ public void testGetExtendedCard_MethodNameSetInContext() {
assertEquals(GET_EXTENDED_AGENT_CARD_METHOD, capturedContext.getState().get(METHOD_NAME_KEY));
}

@Test
public void testTenantExtraction_MultiSegmentPath() {
// Arrange - simulate request to /test/titi
when(mockRoutingContext.normalizedPath()).thenReturn("/test/titi");
String jsonRpcRequest = """
{
"jsonrpc": "2.0",
"id": "cd4c76de-d54c-436c-8b9f-4c2703648d64",
"method": "GetTask",
"params": {
"name": "tasks/de38c76d-d54c-436c-8b9f-4c2703648d64",
"historyLength": 10
}
}""";
when(mockRequestBody.asString()).thenReturn(jsonRpcRequest);

Task responseTask = Task.builder()
.id("de38c76d-d54c-436c-8b9f-4c2703648d64")
.contextId("context-1234")
.status(new TaskStatus(TaskState.SUBMITTED))
.build();
GetTaskResponse realResponse = new GetTaskResponse("1", responseTask);
when(mockJsonRpcHandler.onGetTask(any(GetTaskRequest.class), any(ServerCallContext.class)))
.thenReturn(realResponse);

ArgumentCaptor<ServerCallContext> contextCaptor = ArgumentCaptor.forClass(ServerCallContext.class);

// Act
routes.invokeJSONRPCHandler(jsonRpcRequest, mockRoutingContext);

// Assert
verify(mockJsonRpcHandler).onGetTask(any(GetTaskRequest.class), contextCaptor.capture());
ServerCallContext capturedContext = contextCaptor.getValue();
assertNotNull(capturedContext);
assertEquals("test/titi", capturedContext.getState().get(TENANT_KEY));
}

@Test
public void testTenantExtraction_RootPath() {
// Arrange - simulate request to /
when(mockRoutingContext.normalizedPath()).thenReturn("/");
String jsonRpcRequest = """
{
"jsonrpc": "2.0",
"id": "cd4c76de-d54c-436c-8b9f-4c2703648d64",
"method": "GetTask",
"params": {
"name": "tasks/de38c76d-d54c-436c-8b9f-4c2703648d64",
"historyLength": 10
}
}""";
when(mockRequestBody.asString()).thenReturn(jsonRpcRequest);

Task responseTask = Task.builder()
.id("de38c76d-d54c-436c-8b9f-4c2703648d64")
.contextId("context-1234")
.status(new TaskStatus(TaskState.SUBMITTED))
.build();
GetTaskResponse realResponse = new GetTaskResponse("1", responseTask);
when(mockJsonRpcHandler.onGetTask(any(GetTaskRequest.class), any(ServerCallContext.class)))
.thenReturn(realResponse);

ArgumentCaptor<ServerCallContext> contextCaptor = ArgumentCaptor.forClass(ServerCallContext.class);

// Act
routes.invokeJSONRPCHandler(jsonRpcRequest, mockRoutingContext);

// Assert
verify(mockJsonRpcHandler).onGetTask(any(GetTaskRequest.class), contextCaptor.capture());
ServerCallContext capturedContext = contextCaptor.getValue();
assertNotNull(capturedContext);
assertEquals("", capturedContext.getState().get(TENANT_KEY));
}

@Test
public void testTenantExtraction_SingleSegmentPath() {
// Arrange - simulate request to /tenant1
when(mockRoutingContext.normalizedPath()).thenReturn("/tenant1");
String jsonRpcRequest = """
{
"jsonrpc": "2.0",
"id": "cd4c76de-d54c-436c-8b9f-4c2703648d64",
"method": "GetTask",
"params": {
"name": "tasks/de38c76d-d54c-436c-8b9f-4c2703648d64",
"historyLength": 10
}
}""";
when(mockRequestBody.asString()).thenReturn(jsonRpcRequest);

Task responseTask = Task.builder()
.id("de38c76d-d54c-436c-8b9f-4c2703648d64")
.contextId("context-1234")
.status(new TaskStatus(TaskState.SUBMITTED))
.build();
GetTaskResponse realResponse = new GetTaskResponse("1", responseTask);
when(mockJsonRpcHandler.onGetTask(any(GetTaskRequest.class), any(ServerCallContext.class)))
.thenReturn(realResponse);

ArgumentCaptor<ServerCallContext> contextCaptor = ArgumentCaptor.forClass(ServerCallContext.class);

// Act
routes.invokeJSONRPCHandler(jsonRpcRequest, mockRoutingContext);

// Assert
verify(mockJsonRpcHandler).onGetTask(any(GetTaskRequest.class), contextCaptor.capture());
ServerCallContext capturedContext = contextCaptor.getValue();
assertNotNull(capturedContext);
assertEquals("tenant1", capturedContext.getState().get(TENANT_KEY));
}

@Test
public void testTenantExtraction_ThreeSegmentPath() {
// Arrange - simulate request to /tenant1/api/v1
when(mockRoutingContext.normalizedPath()).thenReturn("/tenant1/api/v1");
String jsonRpcRequest = """
{
"jsonrpc": "2.0",
"id": "cd4c76de-d54c-436c-8b9f-4c2703648d64",
"method": "GetTask",
"params": {
"name": "tasks/de38c76d-d54c-436c-8b9f-4c2703648d64",
"historyLength": 10
}
}""";
when(mockRequestBody.asString()).thenReturn(jsonRpcRequest);

Task responseTask = Task.builder()
.id("de38c76d-d54c-436c-8b9f-4c2703648d64")
.contextId("context-1234")
.status(new TaskStatus(TaskState.SUBMITTED))
.build();
GetTaskResponse realResponse = new GetTaskResponse("1", responseTask);
when(mockJsonRpcHandler.onGetTask(any(GetTaskRequest.class), any(ServerCallContext.class)))
.thenReturn(realResponse);

ArgumentCaptor<ServerCallContext> contextCaptor = ArgumentCaptor.forClass(ServerCallContext.class);

// Act
routes.invokeJSONRPCHandler(jsonRpcRequest, mockRoutingContext);

// Assert
verify(mockJsonRpcHandler).onGetTask(any(GetTaskRequest.class), contextCaptor.capture());
ServerCallContext capturedContext = contextCaptor.getValue();
assertNotNull(capturedContext);
assertEquals("tenant1/api/v1", capturedContext.getState().get(TENANT_KEY));
}

@Test
public void testTenantExtraction_StreamingRequest() {
// Arrange - simulate streaming request to /myTenant/api
when(mockRoutingContext.normalizedPath()).thenReturn("/myTenant/api");
String jsonRpcRequest = """
{
"jsonrpc": "2.0",
"id": "cd4c76de-d54c-436c-8b9f-4c2703648d64",
"method": "SendStreamingMessage",
"params": {
"message": {
"messageId": "message-1234",
"contextId": "context-1234",
"role": "ROLE_USER",
"parts": [
{
"text": "tell me a joke"
}
],
"metadata": {}
},
"configuration": {
"acceptedOutputModes": ["text"],
"blocking": true
},
"metadata": {}
}
}""";
when(mockRequestBody.asString()).thenReturn(jsonRpcRequest);

@SuppressWarnings("unchecked")
Flow.Publisher<SendStreamingMessageResponse> mockPublisher = mock(Flow.Publisher.class);
when(mockJsonRpcHandler.onMessageSendStream(any(SendStreamingMessageRequest.class),
any(ServerCallContext.class))).thenReturn(mockPublisher);

ArgumentCaptor<ServerCallContext> contextCaptor = ArgumentCaptor.forClass(ServerCallContext.class);

// Act
routes.invokeJSONRPCHandler(jsonRpcRequest, mockRoutingContext);

// Assert
verify(mockJsonRpcHandler).onMessageSendStream(any(SendStreamingMessageRequest.class),
contextCaptor.capture());
ServerCallContext capturedContext = contextCaptor.getValue();
assertNotNull(capturedContext);
assertEquals("myTenant/api", capturedContext.getState().get(TENANT_KEY));
}

/**
* Helper method to set a field via reflection for testing purposes.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@
import static io.a2a.spec.A2AMethods.SET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD;
import static io.a2a.spec.A2AMethods.SUBSCRIBE_TO_TASK_METHOD;

import static io.a2a.transport.rest.context.RestContextKeys.TENANT_KEY;

@Singleton
@Authenticated
public class A2AServerRoutes {
Expand Down Expand Up @@ -435,6 +437,7 @@ public String getUsername() {
headerNames.forEach(name -> headers.put(name, rc.request().getHeader(name)));
state.put(HEADERS_KEY, headers);
state.put(METHOD_NAME_KEY, jsonRpcMethodName);
state.put(TENANT_KEY, extractTenant(rc));

// Extract requested protocol version from X-A2A-Version header
String requestedVersion = rc.request().getHeader(A2AHeaders.X_A2A_VERSION);
Expand Down
Loading