Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ API Reference
:toctree: _autosummary/

class_name_to_snake_case
ConverterIdentifier
Identifiable
Identifier
IdentifierT
Expand Down
20 changes: 15 additions & 5 deletions pyrit/exceptions/exception_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
from contextvars import ContextVar
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

from pyrit.identifiers import Identifier


class ComponentRole(Enum):
Expand Down Expand Up @@ -191,7 +193,7 @@ def execution_context(
component_role: ComponentRole,
attack_strategy_name: Optional[str] = None,
attack_identifier: Optional[Dict[str, Any]] = None,
component_identifier: Optional[Dict[str, Any]] = None,
component_identifier: Optional[Union[Identifier, Dict[str, Any]]] = None,
objective_target_conversation_id: Optional[str] = None,
objective: Optional[str] = None,
) -> ExecutionContextManager:
Expand All @@ -203,6 +205,7 @@ def execution_context(
attack_strategy_name: The name of the attack strategy class.
attack_identifier: The identifier from attack.get_identifier().
component_identifier: The identifier from component.get_identifier().
Can be an Identifier object or a dict (legacy format).
objective_target_conversation_id: The objective target conversation ID if available.
objective: The attack objective if available.

Expand All @@ -212,15 +215,22 @@ def execution_context(
# Extract endpoint and component_name from component_identifier if available
endpoint = None
component_name = None
component_id_dict: Optional[Dict[str, Any]] = None
if component_identifier:
endpoint = component_identifier.get("endpoint")
component_name = component_identifier.get("__type__")
if isinstance(component_identifier, Identifier):
endpoint = getattr(component_identifier, "endpoint", None)
component_name = component_identifier.class_name
component_id_dict = component_identifier.to_dict()
else:
endpoint = component_identifier.get("endpoint")
component_name = component_identifier.get("__type__")
component_id_dict = component_identifier

context = ExecutionContext(
component_role=component_role,
attack_strategy_name=attack_strategy_name,
attack_identifier=attack_identifier,
component_identifier=component_identifier,
component_identifier=component_id_dict,
objective_target_conversation_id=objective_target_conversation_id,
endpoint=endpoint,
component_name=component_name,
Expand Down
2 changes: 1 addition & 1 deletion pyrit/executor/attack/multi_turn/chunked_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ async def _score_combined_value_async(
component_role=ComponentRole.OBJECTIVE_SCORER,
attack_strategy_name=self.__class__.__name__,
attack_identifier=self.get_identifier(),
component_identifier=self._objective_scorer.get_identifier().to_dict(),
component_identifier=self._objective_scorer.get_identifier(),
objective=objective,
):
scores = await self._objective_scorer.score_text_async(text=combined_value, objective=objective)
Expand Down
4 changes: 2 additions & 2 deletions pyrit/executor/attack/multi_turn/crescendo.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,7 @@ async def _check_refusal_async(self, context: CrescendoAttackContext, objective:
component_role=ComponentRole.REFUSAL_SCORER,
attack_strategy_name=self.__class__.__name__,
attack_identifier=self.get_identifier(),
component_identifier=self._refusal_scorer.get_identifier().to_dict(),
component_identifier=self._refusal_scorer.get_identifier(),
objective_target_conversation_id=context.session.conversation_id,
objective=context.objective,
):
Expand Down Expand Up @@ -666,7 +666,7 @@ async def _score_response_async(self, *, context: CrescendoAttackContext) -> Sco
component_role=ComponentRole.OBJECTIVE_SCORER,
attack_strategy_name=self.__class__.__name__,
attack_identifier=self.get_identifier(),
component_identifier=self._objective_scorer.get_identifier().to_dict(),
component_identifier=self._objective_scorer.get_identifier(),
objective_target_conversation_id=context.session.conversation_id,
objective=context.objective,
):
Expand Down
2 changes: 1 addition & 1 deletion pyrit/executor/attack/multi_turn/multi_prompt_sending.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ async def _evaluate_response_async(self, *, response: Message, objective: str) -
component_role=ComponentRole.OBJECTIVE_SCORER,
attack_strategy_name=self.__class__.__name__,
attack_identifier=self.get_identifier(),
component_identifier=self._objective_scorer.get_identifier().to_dict() if self._objective_scorer else None,
component_identifier=self._objective_scorer.get_identifier() if self._objective_scorer else None,
objective=objective,
):
scoring_results = await Scorer.score_response_async(
Expand Down
2 changes: 1 addition & 1 deletion pyrit/executor/attack/multi_turn/red_teaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ async def _score_response_async(self, *, context: MultiTurnAttackContext[Any]) -
component_role=ComponentRole.OBJECTIVE_SCORER,
attack_strategy_name=self.__class__.__name__,
attack_identifier=self.get_identifier(),
component_identifier=self._objective_scorer.get_identifier().to_dict(),
component_identifier=self._objective_scorer.get_identifier(),
objective_target_conversation_id=context.session.conversation_id,
objective=context.objective,
):
Expand Down
2 changes: 1 addition & 1 deletion pyrit/executor/attack/multi_turn/tree_of_attacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ async def _score_response_async(self, *, response: Message, objective: str) -> N
component_role=ComponentRole.OBJECTIVE_SCORER,
attack_strategy_name=self._attack_strategy_name,
attack_identifier=self._attack_id,
component_identifier=self._objective_scorer.get_identifier().to_dict(),
component_identifier=self._objective_scorer.get_identifier(),
objective_target_conversation_id=self.objective_target_conversation_id,
objective=objective,
):
Expand Down
2 changes: 1 addition & 1 deletion pyrit/executor/attack/single_turn/prompt_sending.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ async def _evaluate_response_async(
component_role=ComponentRole.OBJECTIVE_SCORER,
attack_strategy_name=self.__class__.__name__,
attack_identifier=self.get_identifier(),
component_identifier=self._objective_scorer.get_identifier().to_dict() if self._objective_scorer else None,
component_identifier=self._objective_scorer.get_identifier() if self._objective_scorer else None,
objective=objective,
):
scoring_results = await Scorer.score_response_async(
Expand Down
2 changes: 2 additions & 0 deletions pyrit/identifiers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
class_name_to_snake_case,
snake_case_to_class_name,
)
from pyrit.identifiers.converter_identifier import ConverterIdentifier
from pyrit.identifiers.identifiable import Identifiable, IdentifierT, LegacyIdentifiable
from pyrit.identifiers.identifier import (
Identifier,
Expand All @@ -16,6 +17,7 @@

__all__ = [
"class_name_to_snake_case",
"ConverterIdentifier",
"Identifiable",
"Identifier",
"IdentifierT",
Expand Down
77 changes: 77 additions & 0 deletions pyrit/identifiers/converter_identifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Type, cast

from pyrit.identifiers.identifier import Identifier


@dataclass(frozen=True)
class ConverterIdentifier(Identifier):
    """
    Identifier for PromptConverter instances.

    This frozen dataclass extends Identifier with converter-specific fields.
    It provides a structured way to identify and track converters used in
    prompt transformations.
    """

    supported_input_types: Tuple[str, ...] = field(kw_only=True)
    """The input data types supported by this converter (e.g., ('text',), ('image', 'text'))."""

    supported_output_types: Tuple[str, ...] = field(kw_only=True)
    """The output data types produced by this converter."""

    sub_identifier: Optional[List["ConverterIdentifier"]] = None
    """List of sub-converter identifiers for composite converters like ConverterPipeline."""

    target_info: Optional[Dict[str, Any]] = None
    """Information about the prompt target used by the converter (for LLM-based converters)."""

    converter_specific_params: Optional[Dict[str, Any]] = None
    """Additional converter-specific parameters."""

    @classmethod
    def from_dict(cls: Type["ConverterIdentifier"], data: dict[str, Any]) -> "ConverterIdentifier":
        """
        Create a ConverterIdentifier from a dictionary (e.g., retrieved from database).

        Extends the base Identifier.from_dict() by normalizing the
        converter-specific fields before delegating to it:

        * dict entries in ``sub_identifier`` are recursively rebuilt into
          ConverterIdentifier objects (non-dict entries are kept as they are);
        * ``supported_input_types`` / ``supported_output_types`` given as lists
          are converted to tuples, and default to ``()`` when the key is
          missing or None (legacy dicts that predate these fields).

        The caller's dictionary is not modified; a shallow copy is normalized.

        Args:
            data: The dictionary representation.

        Returns:
            ConverterIdentifier: A new ConverterIdentifier instance.
        """
        # Work on a shallow copy so the caller's dict is left untouched.
        normalized = dict(data)

        # Recursively reconstruct nested identifiers when present.
        nested = normalized.get("sub_identifier")
        if nested is not None:
            normalized["sub_identifier"] = [
                ConverterIdentifier.from_dict(item) if isinstance(item, dict) else item for item in nested
            ]

        # Both type fields follow the same rule: a missing/None value becomes an
        # empty tuple, a list becomes a tuple, anything else is passed through.
        for type_field in ("supported_input_types", "supported_output_types"):
            raw_types = normalized.get(type_field)
            if raw_types is None:
                normalized[type_field] = ()
            elif isinstance(raw_types, list):
                normalized[type_field] = tuple(raw_types)

        # Delegate to the parent classmethod; super() keeps ``cls`` bound so the
        # parent constructs an instance of the calling class.
        result = super().from_dict(normalized)
        return cast(ConverterIdentifier, result)
27 changes: 15 additions & 12 deletions pyrit/identifiers/identifiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from typing import Generic, Optional, TypeVar

from pyrit.identifiers.identifier import Identifier

Expand Down Expand Up @@ -37,29 +37,32 @@ class Identifiable(ABC, Generic[IdentifierT]):
Generic over IdentifierT, allowing subclasses to specify their exact
identifier type for strong typing support.

Subclasses must:
1. Implement `_build_identifier()` to construct their specific identifier
2. Implement `get_identifier()` to return the typed identifier (can use lazy building)
Subclasses must implement `_build_identifier()` to construct their specific identifier.
The `get_identifier()` method is provided and uses lazy building with caching.
"""

_identifier: Optional[IdentifierT] = None

@abstractmethod
def _build_identifier(self) -> None:
def _build_identifier(self) -> IdentifierT:
"""
Build the identifier for this object.
Build and return the identifier for this object.

Subclasses must implement this method to construct their specific identifier type
and store it in an instance variable (typically `_identifier`).
Subclasses must implement this method to construct their specific identifier type.
This method is called lazily on first access via `get_identifier()`.

This method is typically called lazily on first access via `get_identifier()`.
Returns:
IdentifierT: The constructed identifier for this component.
"""
raise NotImplementedError("Subclasses must implement _build_identifier")

@abstractmethod
def get_identifier(self) -> IdentifierT:
"""
Get the typed identifier for this object.
Get the typed identifier for this object. Built lazily on first access.

Returns:
IdentifierT: The identifier for this component.
"""
...
if self._identifier is None:
self._identifier = self._build_identifier()
return self._identifier
2 changes: 1 addition & 1 deletion pyrit/identifiers/identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _compute_hash(self) -> str:
"""
Compute a stable SHA256 hash from storable identifier fields.

Fields marked with metadata={"exclude_from_storage": True}, 'hash', and 'name'
Fields marked with metadata={"exclude_from_storage": True}, 'hash', and 'unique_name'
are excluded from the hash computation.

Returns:
Expand Down
13 changes: 9 additions & 4 deletions pyrit/memory/memory_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from sqlalchemy.types import Uuid

from pyrit.common.utils import to_sha256
from pyrit.identifiers import ScorerIdentifier
from pyrit.identifiers import ConverterIdentifier, ScorerIdentifier
from pyrit.models import (
AttackOutcome,
AttackResult,
Expand Down Expand Up @@ -164,7 +164,7 @@ class PromptMemoryEntry(Base):
labels: Mapped[dict[str, str]] = mapped_column(JSON)
prompt_metadata: Mapped[dict[str, Union[str, int]]] = mapped_column(JSON)
targeted_harm_categories: Mapped[Optional[List[str]]] = mapped_column(JSON)
converter_identifiers: Mapped[Optional[List[dict[str, str]]]] = mapped_column(JSON)
converter_identifiers: Mapped[Optional[List[Dict[str, str]]]] = mapped_column(JSON)
prompt_target_identifier: Mapped[dict[str, str]] = mapped_column(JSON)
attack_identifier: Mapped[dict[str, str]] = mapped_column(JSON)
response_error: Mapped[Literal["blocked", "none", "processing", "unknown"]] = mapped_column(String, nullable=True)
Expand Down Expand Up @@ -207,7 +207,7 @@ def __init__(self, *, entry: MessagePiece):
self.labels = entry.labels
self.prompt_metadata = entry.prompt_metadata
self.targeted_harm_categories = entry.targeted_harm_categories
self.converter_identifiers = entry.converter_identifiers
self.converter_identifiers = [conv.to_dict() for conv in entry.converter_identifiers]
self.prompt_target_identifier = entry.prompt_target_identifier
self.attack_identifier = entry.attack_identifier

Expand All @@ -230,6 +230,11 @@ def get_message_piece(self) -> MessagePiece:
Returns:
MessagePiece: The reconstructed message piece with all its data and scores.
"""
converter_ids: Optional[List[Union[ConverterIdentifier, Dict[str, str]]]] = (
[ConverterIdentifier.from_dict(c) for c in self.converter_identifiers]
if self.converter_identifiers
else None
)
message_piece = MessagePiece(
role=self.role,
original_value=self.original_value,
Expand All @@ -242,7 +247,7 @@ def get_message_piece(self) -> MessagePiece:
labels=self.labels,
prompt_metadata=self.prompt_metadata,
targeted_harm_categories=self.targeted_harm_categories,
converter_identifiers=self.converter_identifiers,
converter_identifiers=converter_ids,
prompt_target_identifier=self.prompt_target_identifier,
attack_identifier=self.attack_identifier,
original_value_data_type=self.original_value_data_type,
Expand Down
23 changes: 18 additions & 5 deletions pyrit/models/message_piece.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from uuid import uuid4

from pyrit.common.deprecation import print_deprecation_message
from pyrit.identifiers import ScorerIdentifier
from pyrit.identifiers import ConverterIdentifier, ScorerIdentifier
from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError
from pyrit.models.score import Score

Expand Down Expand Up @@ -38,7 +38,7 @@ def __init__(
sequence: int = -1,
labels: Optional[Dict[str, str]] = None,
prompt_metadata: Optional[Dict[str, Union[str, int]]] = None,
converter_identifiers: Optional[List[Dict[str, str]]] = None,
converter_identifiers: Optional[List[Union[ConverterIdentifier, Dict[str, str]]]] = None,
prompt_target_identifier: Optional[Dict[str, str]] = None,
attack_identifier: Optional[Dict[str, str]] = None,
scorer_identifier: Optional[Union[ScorerIdentifier, Dict[str, str]]] = None,
Expand Down Expand Up @@ -69,7 +69,8 @@ def __init__(
Because memory is how components talk with each other, this can be component specific.
e.g. the URI from a file uploaded to a blob store, or a document type you want to upload.
Defaults to None.
converter_identifiers: The converter identifiers for the prompt. Defaults to None.
converter_identifiers: The converter identifiers for the prompt. Can be ConverterIdentifier
objects or dicts (deprecated, will be removed in 0.14.0). Defaults to None.
prompt_target_identifier: The target identifier for the prompt. Defaults to None.
attack_identifier: The attack identifier for the prompt. Defaults to None.
scorer_identifier: The scorer identifier for the prompt. Can be a ScorerIdentifier or a
Expand Down Expand Up @@ -106,7 +107,19 @@ def __init__(
self.labels = labels or {}
self.prompt_metadata = prompt_metadata or {}

self.converter_identifiers = converter_identifiers if converter_identifiers else []
# Handle converter_identifiers: convert dicts to ConverterIdentifier with deprecation warning
self.converter_identifiers: List[ConverterIdentifier] = []
if converter_identifiers:
for conv_id in converter_identifiers:
if isinstance(conv_id, dict):
print_deprecation_message(
old_item="dict for converter_identifiers",
new_item="ConverterIdentifier",
removed_in="0.14.0",
)
self.converter_identifiers.append(ConverterIdentifier.from_dict(conv_id))
else:
self.converter_identifiers.append(conv_id)

self.prompt_target_identifier = prompt_target_identifier or {}
self.attack_identifier = attack_identifier or {}
Expand Down Expand Up @@ -278,7 +291,7 @@ def to_dict(self) -> dict[str, object]:
"labels": self.labels,
"targeted_harm_categories": self.targeted_harm_categories if self.targeted_harm_categories else None,
"prompt_metadata": self.prompt_metadata,
"converter_identifiers": self.converter_identifiers,
"converter_identifiers": [conv.to_dict() for conv in self.converter_identifiers],
"prompt_target_identifier": self.prompt_target_identifier,
"attack_identifier": self.attack_identifier,
"scorer_identifier": self.scorer_identifier.to_dict() if self.scorer_identifier else None,
Expand Down
Loading