Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/app/endpoints/a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ async def _process_task_streaming( # pylint: disable=too-many-locals
generate_topic_summary=True,
media_type=None,
vector_store_ids=vector_store_ids,
shield_ids=None,
)

# Get LLM client and select model
Expand Down
4 changes: 3 additions & 1 deletion src/app/endpoints/query_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,9 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
)

# Run shield moderation before calling LLM
moderation_result = await run_shield_moderation(client, input_text)
moderation_result = await run_shield_moderation(
client, input_text, query_request.shield_ids
)
if moderation_result.blocked:
violation_message = moderation_result.message or ""
await append_turn_to_conversation(
Expand Down
4 changes: 3 additions & 1 deletion src/app/endpoints/streaming_query_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,9 @@ async def retrieve_response( # pylint: disable=too-many-locals
)

# Run shield moderation before calling LLM
moderation_result = await run_shield_moderation(client, input_text)
moderation_result = await run_shield_moderation(
client, input_text, query_request.shield_ids
)
if moderation_result.blocked:
violation_message = moderation_result.message or ""
await append_turn_to_conversation(
Expand Down
9 changes: 9 additions & 0 deletions src/models/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class QueryRequest(BaseModel):
generate_topic_summary: Whether to generate topic summary for new conversations.
media_type: The optional media type for response format (application/json or text/plain).
vector_store_ids: The optional list of specific vector store IDs to query for RAG.
shield_ids: The optional list of safety shield IDs to apply.

Example:
```python
Expand Down Expand Up @@ -166,6 +167,14 @@ class QueryRequest(BaseModel):
examples=["ocp_docs", "knowledge_base", "vector_db_1"],
)

shield_ids: Optional[list[str]] = Field(
None,
description="Optional list of safety shield IDs to apply. "
"If None, all configured shields are used. "
"If empty list, all shields are skipped.",
examples=["llama-guard", "custom-shield"],
)

# provides examples for /docs endpoint
model_config = {
"extra": "forbid",
Expand Down
38 changes: 34 additions & 4 deletions src/utils/shields.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Utility functions for working with Llama Stack shields."""

import logging
from typing import Any, cast
from typing import Any, Optional, cast

from fastapi import HTTPException
from llama_stack_client import AsyncLlamaStackClient, BadRequestError
from llama_stack_client.types import CreateResponse

import metrics
from models.responses import NotFoundResponse
from models.responses import NotFoundResponse, UnprocessableEntityResponse
from utils.types import ShieldModerationResult

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -63,26 +63,56 @@ def detect_shield_violations(output_items: list[Any]) -> bool:
async def run_shield_moderation(
client: AsyncLlamaStackClient,
input_text: str,
shield_ids: Optional[list[str]] = None,
) -> ShieldModerationResult:
"""
Run shield moderation on input text.

Iterates through all configured shields and runs moderation checks.
Iterates through configured shields and runs moderation checks.
Raises HTTPException if shield model is not found.

Parameters:
client: The Llama Stack client.
input_text: The text to moderate.
shield_ids: Optional list of shield IDs to use. If None, uses all shields.
If empty list, skips all shields.

Returns:
ShieldModerationResult: Result indicating if content was blocked and the message.

Raises:
HTTPException: If shield's provider_resource_id is not configured or model not found.
"""
all_shields = await client.shields.list()

# Filter shields based on shield_ids parameter
if shield_ids is not None:
if len(shield_ids) == 0:
logger.info("shield_ids=[] provided, skipping all shields")
return ShieldModerationResult(blocked=False)

shields_to_run = [s for s in all_shields if s.identifier in shield_ids]

# Log warning if requested shield not found
requested = set(shield_ids)
available = {s.identifier for s in shields_to_run}
missing = requested - available
if missing:
logger.warning("Requested shields not found: %s", missing)

# Reject if no requested shields were found (prevents accidental bypass)
if not shields_to_run:
response = UnprocessableEntityResponse(
response="Invalid shield configuration",
cause=f"Requested shield_ids not found: {sorted(missing)}",
)
raise HTTPException(**response.model_dump())
else:
shields_to_run = list(all_shields)

available_models = {model.id for model in await client.models.list()}

for shield in await client.shields.list():
for shield in shields_to_run:
if (
not shield.provider_resource_id
or shield.provider_resource_id not in available_models
Expand Down
69 changes: 69 additions & 0 deletions tests/unit/utils/test_shields.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,75 @@ async def test_returns_blocked_on_bad_request_error(
assert result.shield_model == "moderation-model"
mock_metric.inc.assert_called_once()

@pytest.mark.asyncio
async def test_shield_ids_empty_list_skips_all_shields(
self, mocker: MockerFixture
) -> None:
"""Test that shield_ids=[] explicitly skips all shields (intentional bypass)."""
mock_client = mocker.Mock()
shield = mocker.Mock()
shield.identifier = "shield-1"
mock_client.shields.list = mocker.AsyncMock(return_value=[shield])

result = await run_shield_moderation(mock_client, "test input", shield_ids=[])

assert result.blocked is False
mock_client.shields.list.assert_called_once()

@pytest.mark.asyncio
async def test_shield_ids_raises_exception_when_no_shields_found(
self, mocker: MockerFixture
) -> None:
"""Test shield_ids raises HTTPException when no requested shields exist."""
mock_client = mocker.Mock()
shield = mocker.Mock()
shield.identifier = "shield-1"
mock_client.shields.list = mocker.AsyncMock(return_value=[shield])

with pytest.raises(HTTPException) as exc_info:
await run_shield_moderation(
mock_client, "test input", shield_ids=["typo-shield"]
)

assert exc_info.value.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
assert "Invalid shield configuration" in exc_info.value.detail["response"] # type: ignore
assert "typo-shield" in exc_info.value.detail["cause"] # type: ignore

@pytest.mark.asyncio
async def test_shield_ids_filters_to_specific_shield(
self, mocker: MockerFixture
) -> None:
"""Test that shield_ids filters to only specified shields."""
mock_client = mocker.Mock()

shield1 = mocker.Mock()
shield1.identifier = "shield-1"
shield1.provider_resource_id = "model-1"
shield2 = mocker.Mock()
shield2.identifier = "shield-2"
shield2.provider_resource_id = "model-2"
mock_client.shields.list = mocker.AsyncMock(return_value=[shield1, shield2])

model1 = mocker.Mock()
model1.id = "model-1"
mock_client.models.list = mocker.AsyncMock(return_value=[model1])

moderation_result = mocker.Mock()
moderation_result.results = [mocker.Mock(flagged=False)]
mock_client.moderations.create = mocker.AsyncMock(
return_value=moderation_result
)

result = await run_shield_moderation(
mock_client, "test input", shield_ids=["shield-1"]
)

assert result.blocked is False
assert mock_client.moderations.create.call_count == 1
mock_client.moderations.create.assert_called_with(
input="test input", model="model-1"
)


class TestAppendTurnToConversation: # pylint: disable=too-few-public-methods
"""Tests for append_turn_to_conversation function."""
Expand Down
Loading