Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 143 additions & 26 deletions pyrit/identifiers/identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,113 @@

import hashlib
import json
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from dataclasses import Field, asdict, dataclass, field, fields, is_dataclass
from enum import Enum
from typing import Any, Literal, Type, TypeVar

import pyrit
from pyrit.common.deprecation import print_deprecation_message
from pyrit.identifiers.class_name_utils import class_name_to_snake_case

IdentifierType = Literal["class", "instance"]


class _ExcludeFrom(Enum):
"""
Enum specifying what a field should be excluded from.

Used as values in the _EXCLUDE metadata set for dataclass fields.

Values:
HASH: Exclude the field from hash computation (field is still stored).
STORAGE: Exclude the field from storage (implies HASH - field is also excluded from hash).

The `expands_to` property returns the full set of exclusions that apply.
For example, STORAGE.expands_to returns {STORAGE, HASH} since excluding
from storage implicitly means excluding from hash as well.
"""

HASH = "hash"
STORAGE = "storage"

@property
def expands_to(self) -> set["_ExcludeFrom"]:
"""
Get the full set of exclusions that this value implies.

This implements a catalog pattern where certain exclusions automatically
include others. For example, STORAGE expands to {STORAGE, HASH} because
a field excluded from storage should never be included in the hash.

Returns:
set[_ExcludeFrom]: The complete set of exclusions including implied ones.
"""
_EXPANSION_CATALOG: dict[_ExcludeFrom, set[_ExcludeFrom]] = {
_ExcludeFrom.HASH: {_ExcludeFrom.HASH},
_ExcludeFrom.STORAGE: {_ExcludeFrom.STORAGE, _ExcludeFrom.HASH},
}
return _EXPANSION_CATALOG[self]


def _expand_exclusions(exclude_set: set[_ExcludeFrom]) -> set[_ExcludeFrom]:
    """
    Expand a set of exclusions to include all implied exclusions.

    Args:
        exclude_set: A set of _ExcludeFrom values.

    Returns:
        set[_ExcludeFrom]: The expanded set including all implied exclusions.
    """
    # Union the implication set of every requested exclusion; an empty
    # input yields an empty set.
    return set().union(*(exclusion.expands_to for exclusion in exclude_set))


# Metadata keys for field configuration
EXCLUDE_FROM_STORAGE = "exclude_from_storage"
MAX_STORAGE_LENGTH = "max_storage_length"
# _EXCLUDE is a metadata key whose value is a set of _ExcludeFrom enum values.
# Examples:
# field(metadata={_EXCLUDE: {_ExcludeFrom.HASH}}) # Stored but not hashed
# field(metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}}) # Not stored and not hashed (STORAGE implies HASH)
_EXCLUDE = "exclude"
_MAX_STORAGE_LENGTH = "max_storage_length"


def _is_excluded_from_hash(f: Field[Any]) -> bool:
    """
    Check if a field should be excluded from hash computation.

    A field is excluded from hash if, after expansion, the exclusion set contains _ExcludeFrom.HASH.
    This uses the catalog expansion pattern where STORAGE automatically implies HASH.

    Args:
        f: A dataclass field object.

    Returns:
        True if the field should be excluded from hash computation.
    """
    raw_exclusions = f.metadata.get(_EXCLUDE, set())
    return _ExcludeFrom.HASH in _expand_exclusions(raw_exclusions)


def _is_excluded_from_storage(f: Field[Any]) -> bool:
    """
    Check if a field should be excluded from storage.

    A field is excluded from storage if, after expansion, the exclusion set contains _ExcludeFrom.STORAGE.

    Args:
        f: A dataclass field object.

    Returns:
        True if the field should be excluded from storage.
    """
    raw_exclusions = f.metadata.get(_EXCLUDE, set())
    return _ExcludeFrom.STORAGE in _expand_exclusions(raw_exclusions)


T = TypeVar("T", bound="Identifier")

Expand All @@ -39,41 +135,46 @@ class Identifier:
class_name: str # The actual class name, equivalent to __type__ (e.g., "SelfAskRefusalScorer")
class_module: str # The module path, equivalent to __module__ (e.g., "pyrit.score.self_ask_refusal_scorer")

# Fields excluded from storage
class_description: str = field(metadata={EXCLUDE_FROM_STORAGE: True})
identifier_type: IdentifierType = field(metadata={EXCLUDE_FROM_STORAGE: True})
# Fields excluded from storage (STORAGE auto-expands to include HASH)
class_description: str = field(metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}})
identifier_type: IdentifierType = field(metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}})

# Auto-computed fields
snake_class_name: str = field(init=False, metadata={EXCLUDE_FROM_STORAGE: True})
hash: str | None = field(default=None, compare=False, kw_only=True)
unique_name: str = field(init=False) # Unique identifier: {full_snake_case}::{hash[:8]}
snake_class_name: str = field(init=False, metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}})
hash: str | None = field(default=None, compare=False, kw_only=True, metadata={_EXCLUDE: {_ExcludeFrom.HASH}})

# {full_snake_case}::{hash[:8]}
unique_name: str = field(init=False, metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}})

# Version field - stored but not hashed (allows version tracking without affecting identity)
pyrit_version: str = field(
default_factory=lambda: pyrit.__version__, kw_only=True, metadata={_EXCLUDE: {_ExcludeFrom.HASH}}
)

def __post_init__(self) -> None:
"""Compute derived fields: snake_class_name, hash, and unique_name."""
# Use object.__setattr__ since this is a frozen dataclass
# 1. Compute snake_class_name
object.__setattr__(self, "snake_class_name", class_name_to_snake_case(self.class_name))
# 2. Compute hash only if not already provided (e.g., from from_dict)
if self.hash is None:
object.__setattr__(self, "hash", self._compute_hash())
computed_hash = self.hash if self.hash is not None else self._compute_hash()
object.__setattr__(self, "hash", computed_hash)
# 3. Compute unique_name: full snake_case :: hash prefix
full_snake = class_name_to_snake_case(self.class_name)
object.__setattr__(self, "unique_name", f"{full_snake}::{self.hash[:8]}")
object.__setattr__(self, "unique_name", f"{full_snake}::{computed_hash[:8]}")

def _compute_hash(self) -> str:
"""
Compute a stable SHA256 hash from storable identifier fields.
Compute a stable SHA256 hash from identifier fields not excluded from hashing.

Fields marked with metadata={"exclude_from_storage": True}, 'hash', and 'unique_name'
are excluded from the hash computation.
Fields are excluded from hash computation if they have:
metadata={_EXCLUDE: {_ExcludeFrom.HASH}} or metadata={_EXCLUDE: {_ExcludeFrom.HASH, _ExcludeFrom.STORAGE}}

Returns:
A hex string of the SHA256 hash.
"""
hashable_dict: dict[str, Any] = {
f.name: getattr(self, f.name)
for f in fields(self)
if f.name not in ("hash", "unique_name") and not f.metadata.get(EXCLUDE_FROM_STORAGE, False)
f.name: getattr(self, f.name) for f in fields(self) if not _is_excluded_from_hash(f)
}
config_json = json.dumps(hashable_dict, sort_keys=True, separators=(",", ":"), default=_dataclass_encoder)
return hashlib.sha256(config_json.encode("utf-8")).hexdigest()
Expand All @@ -93,10 +194,10 @@ def to_dict(self) -> dict[str, Any]:
"""
result: dict[str, Any] = {}
for f in fields(self):
if f.metadata.get(EXCLUDE_FROM_STORAGE, False):
if _is_excluded_from_storage(f):
continue
value = getattr(self, f.name)
max_len = f.metadata.get(MAX_STORAGE_LENGTH)
max_len = f.metadata.get(_MAX_STORAGE_LENGTH)
if max_len is not None and isinstance(value, str) and len(value) > max_len:
truncated = value[:max_len]
field_hash = hashlib.sha256(value.encode()).hexdigest()[:16]
Expand All @@ -117,12 +218,6 @@ def from_dict(cls: Type[T], data: dict[str, Any]) -> T:
"""
Create an Identifier from a dictionary (e.g., retrieved from database).

This handles:
- Legacy '__type__' key mapping to 'class_name'
- Legacy 'type' key mapping to 'class_name' (with deprecation warning)
- Legacy '__module__' key mapping to 'class_module'
- Ignoring unknown fields not present in the dataclass

Note:
For fields with max_storage_length, stored values may be truncated
strings like "<first N chars>... [sha256:<hash>]". If a 'hash' key is
Expand Down Expand Up @@ -183,6 +278,28 @@ def from_dict(cls: Type[T], data: dict[str, Any]) -> T:

return cls(**filtered_data)

def with_pyrit_version(self: T, version: str) -> T:
    """
    Create a copy of this Identifier with a different pyrit_version.

    Since Identifier is frozen, this returns a new instance with all the same
    field values except for pyrit_version which is set to the provided value.
    Because pyrit_version is excluded from hash computation, the copy keeps
    the same hash (and therefore the same unique_name) as the original.

    Args:
        version: The pyrit_version to set on the new instance.

    Returns:
        A new Identifier instance with the updated pyrit_version.
    """
    from dataclasses import replace

    # dataclasses.replace re-invokes __init__/__post_init__ with the current
    # init-field values, so storage-excluded fields (class_description,
    # identifier_type) and all subclass fields are carried over directly.
    # The previous to_dict()/from_dict() round trip silently replaced any
    # field configured with a max storage length (e.g. prompt templates on
    # ScorerIdentifier) with its truncated "<prefix>... [sha256:...]" form,
    # and dropped subclass fields that are excluded from storage.
    # NOTE(review): assumes all subclass-declared fields are init=True, as
    # they are for the identifiers visible in this change — confirm for any
    # future subclass with init=False stored fields.
    return replace(self, pyrit_version=version)

@classmethod
def normalize(cls: Type[T], value: T | dict[str, Any]) -> T:
"""
Expand Down
6 changes: 3 additions & 3 deletions pyrit/identifiers/scorer_identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Type, cast

from pyrit.identifiers.identifier import MAX_STORAGE_LENGTH, Identifier
from pyrit.identifiers.identifier import _MAX_STORAGE_LENGTH, Identifier
from pyrit.models.score import ScoreType


Expand All @@ -22,10 +22,10 @@ class ScorerIdentifier(Identifier):
scorer_type: ScoreType = "unknown"
"""The type of scorer ("true_false", "float_scale", or "unknown")."""

system_prompt_template: Optional[str] = field(default=None, metadata={MAX_STORAGE_LENGTH: 100})
system_prompt_template: Optional[str] = field(default=None, metadata={_MAX_STORAGE_LENGTH: 100})
"""The system prompt template used by the scorer. Truncated for storage if > 100 characters."""

user_prompt_template: Optional[str] = field(default=None, metadata={MAX_STORAGE_LENGTH: 100})
user_prompt_template: Optional[str] = field(default=None, metadata={_MAX_STORAGE_LENGTH: 100})
"""The user prompt template used by the scorer. Truncated for storage if > 100 characters."""

sub_identifier: Optional[List["ScorerIdentifier"]] = None
Expand Down
52 changes: 38 additions & 14 deletions pyrit/memory/memory_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
from sqlalchemy.types import Uuid

import pyrit
from pyrit.common.utils import to_sha256
from pyrit.identifiers import ConverterIdentifier, ScorerIdentifier
from pyrit.models import (
Expand All @@ -50,6 +51,9 @@
SeedType,
)

# Default pyrit_version for database records created before version tracking was added
LEGACY_PYRIT_VERSION = "<0.10.0"


class CustomUUID(TypeDecorator[uuid.UUID]):
"""
Expand Down Expand Up @@ -185,6 +189,10 @@ class PromptMemoryEntry(Base):

original_prompt_id = mapped_column(CustomUUID, nullable=False)

# Version of PyRIT used when this entry was created
# Nullable for backwards compatibility with existing databases
pyrit_version = mapped_column(String, nullable=True)

scores: Mapped[List["ScoreEntry"]] = relationship(
"ScoreEntry",
primaryjoin="ScoreEntry.prompt_request_response_id == PromptMemoryEntry.original_prompt_id",
Expand Down Expand Up @@ -222,6 +230,7 @@ def __init__(self, *, entry: MessagePiece):
self.response_error = entry.response_error # type: ignore

self.original_prompt_id = entry.original_prompt_id
self.pyrit_version = pyrit.__version__

def get_message_piece(self) -> MessagePiece:
"""
Expand All @@ -230,11 +239,14 @@ def get_message_piece(self) -> MessagePiece:
Returns:
MessagePiece: The reconstructed message piece with all its data and scores.
"""
converter_ids: Optional[List[Union[ConverterIdentifier, Dict[str, str]]]] = (
[ConverterIdentifier.from_dict(c) for c in self.converter_identifiers]
if self.converter_identifiers
else None
)
# Reconstruct ConverterIdentifiers with the stored pyrit_version
converter_ids: Optional[List[Union[ConverterIdentifier, Dict[str, str]]]] = None
stored_version = self.pyrit_version or LEGACY_PYRIT_VERSION
if self.converter_identifiers:
converter_ids = []
for c in self.converter_identifiers:
converter = ConverterIdentifier.from_dict(c)
converter_ids.append(converter.with_pyrit_version(stored_version))
message_piece = MessagePiece(
role=self.role,
original_value=self.original_value,
Expand Down Expand Up @@ -321,6 +333,9 @@ class ScoreEntry(Base):
timestamp = mapped_column(DateTime, nullable=False)
task = mapped_column(String, nullable=True) # Deprecated: Use objective instead
objective = mapped_column(String, nullable=True)
# Version of PyRIT used when this score was created
# Nullable for backwards compatibility with existing databases
pyrit_version = mapped_column(String, nullable=True)
prompt_request_piece: Mapped["PromptMemoryEntry"] = relationship("PromptMemoryEntry", back_populates="scores")

def __init__(self, *, entry: Score):
Expand All @@ -346,6 +361,7 @@ def __init__(self, *, entry: Score):
# New code should only read from objective
self.task = entry.objective
self.objective = entry.objective
self.pyrit_version = pyrit.__version__

def get_score(self) -> Score:
"""
Expand All @@ -354,10 +370,12 @@ def get_score(self) -> Score:
Returns:
Score: The reconstructed score object with all its data.
"""
# Convert dict back to ScorerIdentifier (Score.__init__ handles None by creating default)
scorer_identifier = (
ScorerIdentifier.from_dict(self.scorer_class_identifier) if self.scorer_class_identifier else None
)
# Convert dict back to ScorerIdentifier with the stored pyrit_version
scorer_identifier = None
stored_version = self.pyrit_version or LEGACY_PYRIT_VERSION
if self.scorer_class_identifier:
scorer_identifier = ScorerIdentifier.from_dict(self.scorer_class_identifier)
scorer_identifier = scorer_identifier.with_pyrit_version(stored_version)
return Score(
id=self.id,
score_value=self.score_value,
Expand Down Expand Up @@ -677,6 +695,9 @@ class AttackResultEntry(Base):
pruned_conversation_ids: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True)
adversarial_chat_conversation_ids: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True)
timestamp = mapped_column(DateTime, nullable=False)
# Version of PyRIT used when this attack result was created
# Nullable for backwards compatibility with existing databases
pyrit_version = mapped_column(String, nullable=True)

last_response: Mapped[Optional["PromptMemoryEntry"]] = relationship(
"PromptMemoryEntry",
Expand Down Expand Up @@ -720,6 +741,7 @@ def __init__(self, *, entry: AttackResult):
] or None

self.timestamp = datetime.now()
self.pyrit_version = pyrit.__version__

@staticmethod
def _get_id_as_uuid(obj: Any) -> Optional[uuid.UUID]:
Expand Down Expand Up @@ -910,21 +932,23 @@ def get_scenario_result(self) -> ScenarioResult:
ScenarioResult object with scenario metadata but empty attack_results
"""
# Recreate ScenarioIdentifier with the stored pyrit_version
stored_version = self.pyrit_version or LEGACY_PYRIT_VERSION
scenario_identifier = ScenarioIdentifier(
name=self.scenario_name,
description=self.scenario_description or "",
scenario_version=self.scenario_version,
init_data=self.scenario_init_data,
pyrit_version=self.pyrit_version,
pyrit_version=stored_version,
)

# Return empty attack_results - will be populated by memory_interface
attack_results: dict[str, list[AttackResult]] = {}

# Convert dict back to ScorerIdentifier for reconstruction
scorer_identifier = (
ScorerIdentifier.from_dict(self.objective_scorer_identifier) if self.objective_scorer_identifier else None
)
# Convert dict back to ScorerIdentifier with the stored pyrit_version
scorer_identifier = None
if self.objective_scorer_identifier:
scorer_identifier = ScorerIdentifier.from_dict(self.objective_scorer_identifier)
scorer_identifier = scorer_identifier.with_pyrit_version(stored_version)

return ScenarioResult(
id=self.id,
Expand Down
Loading