diff --git a/src/app/endpoints/a2a.py b/src/app/endpoints/a2a.py
index 7e3fc015..b0fc9d58 100644
--- a/src/app/endpoints/a2a.py
+++ b/src/app/endpoints/a2a.py
@@ -312,6 +312,7 @@ async def _process_task_streaming(  # pylint: disable=too-many-locals
             generate_topic_summary=True,
             media_type=None,
             vector_store_ids=vector_store_ids,
+            shield_ids=None,
         )
 
         # Get LLM client and select model
diff --git a/src/app/endpoints/query_v2.py b/src/app/endpoints/query_v2.py
index ecc39b07..dfd5f766 100644
--- a/src/app/endpoints/query_v2.py
+++ b/src/app/endpoints/query_v2.py
@@ -401,7 +401,9 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branche
     )
 
     # Run shield moderation before calling LLM
-    moderation_result = await run_shield_moderation(client, input_text)
+    moderation_result = await run_shield_moderation(
+        client, input_text, query_request.shield_ids
+    )
     if moderation_result.blocked:
         violation_message = moderation_result.message or ""
         await append_turn_to_conversation(
diff --git a/src/app/endpoints/streaming_query_v2.py b/src/app/endpoints/streaming_query_v2.py
index e1c02ca4..787b4565 100644
--- a/src/app/endpoints/streaming_query_v2.py
+++ b/src/app/endpoints/streaming_query_v2.py
@@ -451,7 +451,9 @@ async def retrieve_response(  # pylint: disable=too-many-locals
     )
 
     # Run shield moderation before calling LLM
-    moderation_result = await run_shield_moderation(client, input_text)
+    moderation_result = await run_shield_moderation(
+        client, input_text, query_request.shield_ids
+    )
     if moderation_result.blocked:
         violation_message = moderation_result.message or ""
         await append_turn_to_conversation(
diff --git a/src/models/requests.py b/src/models/requests.py
index 18e5b4b6..ccef6a74 100644
--- a/src/models/requests.py
+++ b/src/models/requests.py
@@ -83,6 +83,7 @@ class QueryRequest(BaseModel):
         generate_topic_summary: Whether to generate topic summary for new conversations.
         media_type: The optional media type for response format (application/json or text/plain).
         vector_store_ids: The optional list of specific vector store IDs to query for RAG.
+        shield_ids: The optional list of safety shield IDs to apply.
 
     Example:
         ```python
@@ -166,6 +167,14 @@ class QueryRequest(BaseModel):
         examples=["ocp_docs", "knowledge_base", "vector_db_1"],
     )
 
+    shield_ids: Optional[list[str]] = Field(
+        None,
+        description="Optional list of safety shield IDs to apply. "
+        "If None, all configured shields are used. "
+        "If empty list, all shields are skipped.",
+        examples=["llama-guard", "custom-shield"],
+    )
+
     # provides examples for /docs endpoint
     model_config = {
         "extra": "forbid",
diff --git a/src/utils/shields.py b/src/utils/shields.py
index 065cc96e..9e4a929c 100644
--- a/src/utils/shields.py
+++ b/src/utils/shields.py
@@ -1,14 +1,14 @@
 """Utility functions for working with Llama Stack shields."""
 
 import logging
-from typing import Any, cast
+from typing import Any, Optional, cast
 
 from fastapi import HTTPException
 from llama_stack_client import AsyncLlamaStackClient, BadRequestError
 from llama_stack_client.types import CreateResponse
 
 import metrics
-from models.responses import NotFoundResponse
+from models.responses import NotFoundResponse, UnprocessableEntityResponse
 from utils.types import ShieldModerationResult
 
 logger = logging.getLogger(__name__)
@@ -63,16 +63,19 @@ def detect_shield_violations(output_items: list[Any]) -> bool:
 async def run_shield_moderation(
     client: AsyncLlamaStackClient,
     input_text: str,
+    shield_ids: Optional[list[str]] = None,
 ) -> ShieldModerationResult:
     """
     Run shield moderation on input text.
 
-    Iterates through all configured shields and runs moderation checks.
+    Iterates through configured shields and runs moderation checks.
     Raises HTTPException if shield model is not found.
 
     Parameters:
         client: The Llama Stack client.
         input_text: The text to moderate.
+        shield_ids: Optional list of shield IDs to use. If None, uses all shields.
+            If empty list, skips all shields.
 
     Returns:
         ShieldModerationResult: Result indicating if content was blocked and the message.
@@ -80,9 +83,39 @@ async def run_shield_moderation(
     Raises:
-        HTTPException: If shield's provider_resource_id is not configured or model not found.
+        HTTPException: If shield's provider_resource_id is not configured or model not found,
+            or (422) if none of the requested shield_ids match a configured shield.
     """
+    # An explicit empty list is an intentional bypass; return before making
+    # any remote call so the bypass does not depend on the shields API.
+    if shield_ids is not None and not shield_ids:
+        logger.info("shield_ids=[] provided, skipping all shields")
+        return ShieldModerationResult(blocked=False)
+
+    all_shields = await client.shields.list()
+
+    # Filter shields based on shield_ids parameter
+    if shield_ids is not None:
+        requested = set(shield_ids)
+        shields_to_run = [s for s in all_shields if s.identifier in requested]
+
+        # Log warning if requested shield not found
+        found = {s.identifier for s in shields_to_run}
+        missing = sorted(requested - found)
+        if missing:
+            logger.warning("Requested shields not found: %s", missing)
+
+        # Reject if no requested shields were found (prevents accidental bypass)
+        if not shields_to_run:
+            response = UnprocessableEntityResponse(
+                response="Invalid shield configuration",
+                cause=f"Requested shield_ids not found: {missing}",
+            )
+            raise HTTPException(**response.model_dump())
+    else:
+        shields_to_run = list(all_shields)
+
     available_models = {model.id for model in await client.models.list()}
 
-    for shield in await client.shields.list():
+    for shield in shields_to_run:
         if (
             not shield.provider_resource_id
             or shield.provider_resource_id not in available_models
diff --git a/tests/unit/utils/test_shields.py b/tests/unit/utils/test_shields.py
index adf3fe8b..b33a4289 100644
--- a/tests/unit/utils/test_shields.py
+++ b/tests/unit/utils/test_shields.py
@@ -312,6 +312,76 @@ async def test_returns_blocked_on_bad_request_error(
         assert result.shield_model == "moderation-model"
         mock_metric.inc.assert_called_once()
 
+    @pytest.mark.asyncio
+    async def test_shield_ids_empty_list_skips_all_shields(
+        self, mocker: MockerFixture
+    ) -> None:
+        """Test that shield_ids=[] explicitly skips all shields (intentional bypass)."""
+        mock_client = mocker.Mock()
+        shield = mocker.Mock()
+        shield.identifier = "shield-1"
+        mock_client.shields.list = mocker.AsyncMock(return_value=[shield])
+
+        result = await run_shield_moderation(mock_client, "test input", shield_ids=[])
+
+        assert result.blocked is False
+        # The bypass must short-circuit before any remote call is made.
+        mock_client.shields.list.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_shield_ids_raises_exception_when_no_shields_found(
+        self, mocker: MockerFixture
+    ) -> None:
+        """Test shield_ids raises HTTPException when no requested shields exist."""
+        mock_client = mocker.Mock()
+        shield = mocker.Mock()
+        shield.identifier = "shield-1"
+        mock_client.shields.list = mocker.AsyncMock(return_value=[shield])
+
+        with pytest.raises(HTTPException) as exc_info:
+            await run_shield_moderation(
+                mock_client, "test input", shield_ids=["typo-shield"]
+            )
+
+        assert exc_info.value.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
+        assert "Invalid shield configuration" in exc_info.value.detail["response"]  # type: ignore
+        assert "typo-shield" in exc_info.value.detail["cause"]  # type: ignore
+
+    @pytest.mark.asyncio
+    async def test_shield_ids_filters_to_specific_shield(
+        self, mocker: MockerFixture
+    ) -> None:
+        """Test that shield_ids filters to only specified shields."""
+        mock_client = mocker.Mock()
+
+        shield1 = mocker.Mock()
+        shield1.identifier = "shield-1"
+        shield1.provider_resource_id = "model-1"
+        shield2 = mocker.Mock()
+        shield2.identifier = "shield-2"
+        shield2.provider_resource_id = "model-2"
+        mock_client.shields.list = mocker.AsyncMock(return_value=[shield1, shield2])
+
+        model1 = mocker.Mock()
+        model1.id = "model-1"
+        mock_client.models.list = mocker.AsyncMock(return_value=[model1])
+
+        moderation_result = mocker.Mock()
+        moderation_result.results = [mocker.Mock(flagged=False)]
+        mock_client.moderations.create = mocker.AsyncMock(
+            return_value=moderation_result
+        )
+
+        result = await run_shield_moderation(
+            mock_client, "test input", shield_ids=["shield-1"]
+        )
+
+        assert result.blocked is False
+        assert mock_client.moderations.create.call_count == 1
+        mock_client.moderations.create.assert_called_with(
+            input="test input", model="model-1"
+        )
+
 
 class TestAppendTurnToConversation:  # pylint: disable=too-few-public-methods
     """Tests for append_turn_to_conversation function."""