diff --git a/pyrit/identifiers/identifier.py b/pyrit/identifiers/identifier.py index 3d698e717..f002ce266 100644 --- a/pyrit/identifiers/identifier.py +++ b/pyrit/identifiers/identifier.py @@ -5,17 +5,113 @@ import hashlib import json -from dataclasses import asdict, dataclass, field, fields, is_dataclass +from dataclasses import Field, asdict, dataclass, field, fields, is_dataclass +from enum import Enum from typing import Any, Literal, Type, TypeVar +import pyrit from pyrit.common.deprecation import print_deprecation_message from pyrit.identifiers.class_name_utils import class_name_to_snake_case IdentifierType = Literal["class", "instance"] + +class _ExcludeFrom(Enum): + """ + Enum specifying what a field should be excluded from. + + Used as values in the _EXCLUDE metadata set for dataclass fields. + + Values: + HASH: Exclude the field from hash computation (field is still stored). + STORAGE: Exclude the field from storage (implies HASH - field is also excluded from hash). + + The `expands_to` property returns the full set of exclusions that apply. + For example, STORAGE.expands_to returns {STORAGE, HASH} since excluding + from storage implicitly means excluding from hash as well. + """ + + HASH = "hash" + STORAGE = "storage" + + @property + def expands_to(self) -> set["_ExcludeFrom"]: + """ + Get the full set of exclusions that this value implies. + + This implements a catalog pattern where certain exclusions automatically + include others. For example, STORAGE expands to {STORAGE, HASH} because + a field excluded from storage should never be included in the hash. + + Returns: + set[_ExcludeFrom]: The complete set of exclusions including implied ones. + """ + _EXPANSION_CATALOG: dict[_ExcludeFrom, set[_ExcludeFrom]] = { + _ExcludeFrom.HASH: {_ExcludeFrom.HASH}, + _ExcludeFrom.STORAGE: {_ExcludeFrom.STORAGE, _ExcludeFrom.HASH}, + } + return _EXPANSION_CATALOG[self] + + +def _expand_exclusions(exclude_set: set[_ExcludeFrom]) -> set[_ExcludeFrom]: + """ + Expand a set of exclusions to include all implied exclusions. + + Args: + exclude_set: A set of _ExcludeFrom values. + + Returns: + set[_ExcludeFrom]: The expanded set including all implied exclusions. + """ + expanded: set[_ExcludeFrom] = set() + for exclusion in exclude_set: + expanded.update(exclusion.expands_to) + return expanded + + # Metadata keys for field configuration -EXCLUDE_FROM_STORAGE = "exclude_from_storage" -MAX_STORAGE_LENGTH = "max_storage_length" +# _EXCLUDE is a metadata key whose value is a set of _ExcludeFrom enum values. +# Examples: +# field(metadata={_EXCLUDE: {_ExcludeFrom.HASH}}) # Stored but not hashed +# field(metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}}) # Not stored and not hashed (STORAGE implies HASH) +_EXCLUDE = "exclude" +_MAX_STORAGE_LENGTH = "max_storage_length" + + +def _is_excluded_from_hash(f: Field[Any]) -> bool: + """ + Check if a field should be excluded from hash computation. + + A field is excluded from hash if, after expansion, the exclusion set contains _ExcludeFrom.HASH. + This uses the catalog expansion pattern where STORAGE automatically implies HASH. + + Args: + f: A dataclass field object. + + Returns: + True if the field should be excluded from hash computation. + """ + exclude_set = f.metadata.get(_EXCLUDE, set()) + expanded = _expand_exclusions(exclude_set) + return _ExcludeFrom.HASH in expanded + + +def _is_excluded_from_storage(f: Field[Any]) -> bool: + """ + Check if a field should be excluded from storage. + + A field is excluded from storage if, after expansion, the exclusion set contains _ExcludeFrom.STORAGE. + + Args: + f: A dataclass field object. + + Returns: + True if the field should be excluded from storage. + """ + exclude_set = f.metadata.get(_EXCLUDE, set()) + expanded = _expand_exclusions(exclude_set) + return _ExcludeFrom.STORAGE in expanded + T = TypeVar("T", bound="Identifier") @@ -39,14 +135,21 @@ class Identifier: class_name: str # The actual class name, equivalent to __type__ (e.g., "SelfAskRefusalScorer") class_module: str # The module path, equivalent to __module__ (e.g., "pyrit.score.self_ask_refusal_scorer") - # Fields excluded from storage - class_description: str = field(metadata={EXCLUDE_FROM_STORAGE: True}) - identifier_type: IdentifierType = field(metadata={EXCLUDE_FROM_STORAGE: True}) + # Fields excluded from storage (STORAGE auto-expands to include HASH) + class_description: str = field(metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}}) + identifier_type: IdentifierType = field(metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}}) # Auto-computed fields - snake_class_name: str = field(init=False, metadata={EXCLUDE_FROM_STORAGE: True}) - hash: str | None = field(default=None, compare=False, kw_only=True) - unique_name: str = field(init=False) # Unique identifier: {full_snake_case}::{hash[:8]} + snake_class_name: str = field(init=False, metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}}) + hash: str | None = field(default=None, compare=False, kw_only=True, metadata={_EXCLUDE: {_ExcludeFrom.HASH}}) + + # {full_snake_case}::{hash[:8]} + unique_name: str = field(init=False, metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}}) + + # Version field - stored but not hashed (allows version tracking without affecting identity) + pyrit_version: str = field( + default_factory=lambda: pyrit.__version__, kw_only=True, metadata={_EXCLUDE: {_ExcludeFrom.HASH}} + ) def __post_init__(self) -> None: """Compute derived fields: snake_class_name, hash, and unique_name.""" @@ -54,26 +157,24 @@ def __post_init__(self) -> None: # 1. Compute snake_class_name object.__setattr__(self, "snake_class_name", class_name_to_snake_case(self.class_name)) # 2. Compute hash only if not already provided (e.g., from from_dict) - if self.hash is None: - object.__setattr__(self, "hash", self._compute_hash()) + computed_hash = self.hash if self.hash is not None else self._compute_hash() + object.__setattr__(self, "hash", computed_hash) # 3. Compute unique_name: full snake_case :: hash prefix full_snake = class_name_to_snake_case(self.class_name) - object.__setattr__(self, "unique_name", f"{full_snake}::{self.hash[:8]}") + object.__setattr__(self, "unique_name", f"{full_snake}::{computed_hash[:8]}") def _compute_hash(self) -> str: """ - Compute a stable SHA256 hash from storable identifier fields. + Compute a stable SHA256 hash from identifier fields not excluded from hashing. - Fields marked with metadata={"exclude_from_storage": True}, 'hash', and 'unique_name' - are excluded from the hash computation. + Fields are excluded from hash computation if they have: + metadata={_EXCLUDE: {_ExcludeFrom.HASH}} or metadata={_EXCLUDE: {_ExcludeFrom.HASH, _ExcludeFrom.STORAGE}} Returns: A hex string of the SHA256 hash. """ hashable_dict: dict[str, Any] = { - f.name: getattr(self, f.name) - for f in fields(self) - if f.name not in ("hash", "unique_name") and not f.metadata.get(EXCLUDE_FROM_STORAGE, False) + f.name: getattr(self, f.name) for f in fields(self) if not _is_excluded_from_hash(f) } config_json = json.dumps(hashable_dict, sort_keys=True, separators=(",", ":"), default=_dataclass_encoder) return hashlib.sha256(config_json.encode("utf-8")).hexdigest() @@ -93,10 +194,10 @@ def to_dict(self) -> dict[str, Any]: """ result: dict[str, Any] = {} for f in fields(self): - if f.metadata.get(EXCLUDE_FROM_STORAGE, False): + if _is_excluded_from_storage(f): continue value = getattr(self, f.name) - max_len = f.metadata.get(MAX_STORAGE_LENGTH) + max_len = f.metadata.get(_MAX_STORAGE_LENGTH) if max_len is not None and isinstance(value, str) and len(value) > max_len: truncated = value[:max_len] field_hash = hashlib.sha256(value.encode()).hexdigest()[:16] @@ -117,12 +218,6 @@ def from_dict(cls: Type[T], data: dict[str, Any]) -> T: """ Create an Identifier from a dictionary (e.g., retrieved from database). - This handles: - - Legacy '__type__' key mapping to 'class_name' - - Legacy 'type' key mapping to 'class_name' (with deprecation warning) - - Legacy '__module__' key mapping to 'class_module' - - Ignoring unknown fields not present in the dataclass - Note: For fields with max_storage_length, stored values may be truncated strings like "... [sha256:]". If a 'hash' key is @@ -183,6 +278,28 @@ def from_dict(cls: Type[T], data: dict[str, Any]) -> T: return cls(**filtered_data) + def with_pyrit_version(self: T, version: str) -> T: + """ + Create a copy of this Identifier with a different pyrit_version. + + Since Identifier is frozen, this returns a new instance with all the same + field values except for pyrit_version which is set to the provided value. + + Args: + version: The pyrit_version to set on the new instance. + + Returns: + A new Identifier instance with the updated pyrit_version. + """ + # Get all current field values + current_data = self.to_dict() + # Override pyrit_version + current_data["pyrit_version"] = version + # Add back fields excluded from storage that are needed for from_dict + current_data["class_description"] = self.class_description + current_data["identifier_type"] = self.identifier_type + return type(self).from_dict(current_data) + @classmethod def normalize(cls: Type[T], value: T | dict[str, Any]) -> T: """ diff --git a/pyrit/identifiers/scorer_identifier.py b/pyrit/identifiers/scorer_identifier.py index 7707e7b22..d467504fe 100644 --- a/pyrit/identifiers/scorer_identifier.py +++ b/pyrit/identifiers/scorer_identifier.py @@ -6,7 +6,7 @@ from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Type, cast -from pyrit.identifiers.identifier import MAX_STORAGE_LENGTH, Identifier +from pyrit.identifiers.identifier import _MAX_STORAGE_LENGTH, Identifier from pyrit.models.score import ScoreType @@ -22,10 +22,10 @@ class ScorerIdentifier(Identifier): scorer_type: ScoreType = "unknown" """The type of scorer ("true_false", "float_scale", or "unknown").""" - system_prompt_template: Optional[str] = field(default=None, metadata={MAX_STORAGE_LENGTH: 100}) + system_prompt_template: Optional[str] = field(default=None, metadata={_MAX_STORAGE_LENGTH: 100}) """The system prompt template used by the scorer. Truncated for storage if > 100 characters.""" - user_prompt_template: Optional[str] = field(default=None, metadata={MAX_STORAGE_LENGTH: 100}) + user_prompt_template: Optional[str] = field(default=None, metadata={_MAX_STORAGE_LENGTH: 100}) """The user prompt template used by the scorer. Truncated for storage if > 100 characters.""" sub_identifier: Optional[List["ScorerIdentifier"]] = None diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index dd787f718..d36b140c2 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -30,6 +30,7 @@ ) from sqlalchemy.types import Uuid +import pyrit from pyrit.common.utils import to_sha256 from pyrit.identifiers import ConverterIdentifier, ScorerIdentifier from pyrit.models import ( @@ -50,6 +51,9 @@ SeedType, ) +# Default pyrit_version for database records created before version tracking was added +LEGACY_PYRIT_VERSION = "<0.10.0" + class CustomUUID(TypeDecorator[uuid.UUID]): """ @@ -185,6 +189,10 @@ class PromptMemoryEntry(Base): original_prompt_id = mapped_column(CustomUUID, nullable=False) + # Version of PyRIT used when this entry was created + # Nullable for backwards compatibility with existing databases + pyrit_version = mapped_column(String, nullable=True) + scores: Mapped[List["ScoreEntry"]] = relationship( "ScoreEntry", primaryjoin="ScoreEntry.prompt_request_response_id == PromptMemoryEntry.original_prompt_id", @@ -222,6 +230,7 @@ def __init__(self, *, entry: MessagePiece): self.response_error = entry.response_error # type: ignore self.original_prompt_id = entry.original_prompt_id + self.pyrit_version = pyrit.__version__ def get_message_piece(self) -> MessagePiece: """ @@ -230,11 +239,14 @@ def get_message_piece(self) -> MessagePiece: Returns: MessagePiece: The reconstructed message piece with all its data and scores. """ - converter_ids: Optional[List[Union[ConverterIdentifier, Dict[str, str]]]] = ( - [ConverterIdentifier.from_dict(c) for c in self.converter_identifiers] - if self.converter_identifiers - else None - ) + # Reconstruct ConverterIdentifiers with the stored pyrit_version + converter_ids: Optional[List[Union[ConverterIdentifier, Dict[str, str]]]] = None + stored_version = self.pyrit_version or LEGACY_PYRIT_VERSION + if self.converter_identifiers: + converter_ids = [] + for c in self.converter_identifiers: + converter = ConverterIdentifier.from_dict(c) + converter_ids.append(converter.with_pyrit_version(stored_version)) message_piece = MessagePiece( role=self.role, original_value=self.original_value, @@ -321,6 +333,9 @@ class ScoreEntry(Base): timestamp = mapped_column(DateTime, nullable=False) task = mapped_column(String, nullable=True) # Deprecated: Use objective instead objective = mapped_column(String, nullable=True) + # Version of PyRIT used when this score was created + # Nullable for backwards compatibility with existing databases + pyrit_version = mapped_column(String, nullable=True) prompt_request_piece: Mapped["PromptMemoryEntry"] = relationship("PromptMemoryEntry", back_populates="scores") def __init__(self, *, entry: Score): @@ -346,6 +361,7 @@ def __init__(self, *, entry: Score): # New code should only read from objective self.task = entry.objective self.objective = entry.objective + self.pyrit_version = pyrit.__version__ def get_score(self) -> Score: """ @@ -354,10 +370,12 @@ def get_score(self) -> Score: Returns: Score: The reconstructed score object with all its data. """ - # Convert dict back to ScorerIdentifier (Score.__init__ handles None by creating default) - scorer_identifier = ( - ScorerIdentifier.from_dict(self.scorer_class_identifier) if self.scorer_class_identifier else None - ) + # Convert dict back to ScorerIdentifier with the stored pyrit_version + scorer_identifier = None + stored_version = self.pyrit_version or LEGACY_PYRIT_VERSION + if self.scorer_class_identifier: + scorer_identifier = ScorerIdentifier.from_dict(self.scorer_class_identifier) + scorer_identifier = scorer_identifier.with_pyrit_version(stored_version) return Score( id=self.id, score_value=self.score_value, @@ -677,6 +695,9 @@ class AttackResultEntry(Base): pruned_conversation_ids: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True) adversarial_chat_conversation_ids: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True) timestamp = mapped_column(DateTime, nullable=False) + # Version of PyRIT used when this attack result was created + # Nullable for backwards compatibility with existing databases + pyrit_version = mapped_column(String, nullable=True) last_response: Mapped[Optional["PromptMemoryEntry"]] = relationship( "PromptMemoryEntry", @@ -720,6 +741,7 @@ def __init__(self, *, entry: AttackResult): ] or None self.timestamp = datetime.now() + self.pyrit_version = pyrit.__version__ @staticmethod def _get_id_as_uuid(obj: Any) -> Optional[uuid.UUID]: @@ -910,21 +932,23 @@ def get_scenario_result(self) -> ScenarioResult: ScenarioResult object with scenario metadata but empty attack_results """ # Recreate ScenarioIdentifier with the stored pyrit_version + stored_version = self.pyrit_version or LEGACY_PYRIT_VERSION scenario_identifier = ScenarioIdentifier( name=self.scenario_name, description=self.scenario_description or "", scenario_version=self.scenario_version, init_data=self.scenario_init_data, - pyrit_version=self.pyrit_version, + pyrit_version=stored_version, ) # Return empty attack_results - will be populated by memory_interface attack_results: dict[str, list[AttackResult]] = {} - # Convert dict back to ScorerIdentifier for reconstruction - scorer_identifier = ( - ScorerIdentifier.from_dict(self.objective_scorer_identifier) if self.objective_scorer_identifier else None - ) + # Convert dict back to ScorerIdentifier with the stored pyrit_version + scorer_identifier = None + if self.objective_scorer_identifier: + scorer_identifier = ScorerIdentifier.from_dict(self.objective_scorer_identifier) + scorer_identifier = scorer_identifier.with_pyrit_version(stored_version) return ScenarioResult( id=self.id, diff --git a/tests/unit/identifiers/test_identifiers.py b/tests/unit/identifiers/test_identifiers.py index b5e0a0969..a9e373a7f 100644 --- a/tests/unit/identifiers/test_identifiers.py +++ b/tests/unit/identifiers/test_identifiers.py @@ -5,7 +5,9 @@ import pytest +import pyrit from pyrit.identifiers import Identifier, LegacyIdentifiable +from pyrit.identifiers.identifier import _EXCLUDE, _ExcludeFrom, _expand_exclusions class TestLegacyIdentifiable: @@ -182,7 +184,7 @@ class TestIdentifierStorage: """Tests for Identifier storage functionality.""" def test_to_dict_excludes_marked_fields(self): - """Test that to_dict excludes fields marked with exclude_from_storage.""" + """Test that to_dict excludes fields marked with _EXCLUDE containing _ExcludeFrom.STORAGE.""" identifier = Identifier( identifier_type="class", class_name="TestClass", @@ -192,14 +194,16 @@ def test_to_dict_excludes_marked_fields(self): storage_dict = identifier.to_dict() # Should include storable fields - assert "unique_name" in storage_dict assert "class_name" in storage_dict assert "class_module" in storage_dict assert "hash" in storage_dict + assert "pyrit_version" in storage_dict - # Should exclude non-storable fields + # Should exclude non-storable fields (marked with _ExcludeFrom.STORAGE) assert "class_description" not in storage_dict assert "identifier_type" not in storage_dict + assert "snake_class_name" not in storage_dict + assert "unique_name" not in storage_dict def test_to_dict_values_match(self): """Test that to_dict values match the original identifier.""" @@ -211,8 +215,10 @@ def test_to_dict_values_match(self): ) storage_dict = identifier.to_dict() - # unique_name is auto-computed - assert storage_dict["unique_name"] == identifier.unique_name + # unique_name and snake_class_name are excluded from storage + assert "unique_name" not in storage_dict + assert "snake_class_name" not in storage_dict + # Stored fields match assert storage_dict["class_name"] == "MyScorer" assert storage_dict["class_module"] == "pyrit.score.my_scorer" assert storage_dict["hash"] == identifier.hash @@ -226,7 +232,7 @@ def test_subclass_inherits_hash_computation(self): @dataclass(frozen=True) class ExtendedIdentifier(Identifier): - extra_field: str + extra_field: str = field(kw_only=True) extended = ExtendedIdentifier( class_name="TestClass", @@ -243,7 +249,7 @@ def test_subclass_extra_fields_included_in_hash(self): @dataclass(frozen=True) class ExtendedIdentifier(Identifier): - extra_field: str + extra_field: str = field(kw_only=True) extended1 = ExtendedIdentifier( class_name="TestClass", @@ -263,11 +269,12 @@ class ExtendedIdentifier(Identifier): assert extended1.hash != extended2.hash def test_subclass_excluded_fields_not_in_hash(self): - """Test that subclass fields marked exclude_from_storage are excluded from hash.""" + """Test that subclass fields with _ExcludeFrom.STORAGE in _EXCLUDE are excluded from hash via expansion.""" @dataclass(frozen=True) class ExtendedIdentifier(Identifier): - display_only: str = field(default="", metadata={"exclude_from_storage": True}) + # Only need STORAGE - HASH is implied via expansion + display_only: str = field(default="", metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}}) extended1 = ExtendedIdentifier( class_name="TestClass", @@ -291,8 +298,9 @@ def test_subclass_to_dict_includes_extra_storable_fields(self): @dataclass(frozen=True) class ExtendedIdentifier(Identifier): - extra_field: str - display_only: str = field(default="", metadata={"exclude_from_storage": True}) + extra_field: str = field(kw_only=True) + # Only need STORAGE - HASH is implied via expansion + display_only: str = field(default="", metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}}) extended = ExtendedIdentifier( class_name="TestClass", @@ -378,3 +386,198 @@ def test_from_dict_computes_hash_when_not_provided(self): assert len(identifier.hash) == 64 # unique_name should use the computed hash assert identifier.unique_name == f"test_class::{identifier.hash[:8]}" + + +class TestPyritVersion: + """Tests for the pyrit_version field on Identifier.""" + + def test_pyrit_version_is_set_by_default(self): + """Test that pyrit_version is automatically set to the current pyrit version.""" + identifier = Identifier( + identifier_type="class", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + ) + assert identifier.pyrit_version == pyrit.__version__ + + def test_pyrit_version_can_be_overridden(self): + """Test that pyrit_version can be explicitly provided.""" + identifier = Identifier( + identifier_type="class", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + pyrit_version="1.0.0", + ) + assert identifier.pyrit_version == "1.0.0" + + def test_pyrit_version_is_excluded_from_hash(self): + """Test that pyrit_version is excluded from hash computation.""" + identifier1 = Identifier( + identifier_type="class", + class_name="TestClass", + class_module="test.module", + class_description="Description", + pyrit_version="1.0.0", + ) + identifier2 = Identifier( + identifier_type="class", + class_name="TestClass", + class_module="test.module", + class_description="Description", + pyrit_version="2.0.0", + ) + # Hash should be the same since pyrit_version is excluded + assert identifier1.hash == identifier2.hash + + def test_pyrit_version_is_included_in_storage(self): + """Test that pyrit_version is included in to_dict output.""" + identifier = Identifier( + identifier_type="class", + class_name="TestClass", + class_module="test.module", + class_description="A test description", + pyrit_version="1.0.0", + ) + storage_dict = identifier.to_dict() + assert "pyrit_version" in storage_dict + assert storage_dict["pyrit_version"] == "1.0.0" + + def test_from_dict_preserves_pyrit_version(self): + """Test that from_dict preserves the pyrit_version from the dict.""" + data = { + "class_name": "TestClass", + "class_module": "test.module", + "pyrit_version": "0.5.0", + } + + identifier = Identifier.from_dict(data) + assert identifier.pyrit_version == "0.5.0" + + def test_from_dict_defaults_pyrit_version_when_missing(self): + """Test that from_dict defaults pyrit_version to current version when not in dict.""" + data = { + "class_name": "TestClass", + "class_module": "test.module", + } + + identifier = Identifier.from_dict(data) + assert identifier.pyrit_version == pyrit.__version__ + + +class TestExcludeMetadata: + """Tests for the _EXCLUDE metadata field configuration.""" + + def test_storage_exclusion_implies_hash_exclusion(self): + """Test that STORAGE exclusion automatically implies HASH exclusion via expansion. + + This validates the catalog expansion pattern where _ExcludeFrom.STORAGE.expands_to + returns {STORAGE, HASH}, ensuring fields excluded from storage are also excluded + from hash computation. + """ + # Verify the expansion catalog works correctly + assert _ExcludeFrom.HASH in _ExcludeFrom.STORAGE.expands_to + assert _ExcludeFrom.STORAGE in _ExcludeFrom.STORAGE.expands_to + + # Verify HASH only expands to itself + assert _ExcludeFrom.HASH.expands_to == {_ExcludeFrom.HASH} + + # Verify _expand_exclusions works on sets + expanded = _expand_exclusions({_ExcludeFrom.STORAGE}) + assert _ExcludeFrom.HASH in expanded + assert _ExcludeFrom.STORAGE in expanded + + def test_subclass_storage_only_exclusion_works_via_expansion(self): + """Test that subclass fields with only _ExcludeFrom.STORAGE work correctly via expansion. + + With the catalog expansion pattern, specifying only STORAGE automatically implies HASH. + This is the recommended pattern - no need to explicitly specify both. + """ + + @dataclass(frozen=True) + class ValidIdentifier(Identifier): + # Only STORAGE is needed - HASH is implied via expansion + transient_field: str = field(default="", metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}}) + + id1 = ValidIdentifier( + class_name="Test", + class_module="test", + identifier_type="class", + class_description="desc", + transient_field="value1", + ) + id2 = ValidIdentifier( + class_name="Test", + class_module="test", + identifier_type="class", + class_description="desc", + transient_field="value2", + ) + + # Hash should be the same (HASH exclusion is implied by STORAGE) + assert id1.hash == id2.hash + + # Field should not be in storage + assert "transient_field" not in id1.to_dict() + + def test_hash_only_exclusion_works(self): + """Test that a field can be excluded from hash only (still stored).""" + + @dataclass(frozen=True) + class ExtendedIdentifier(Identifier): + metadata_field: str = field(default="", metadata={_EXCLUDE: {_ExcludeFrom.HASH}}) + + extended1 = ExtendedIdentifier( + class_name="TestClass", + class_module="test.module", + identifier_type="class", + class_description="Description", + metadata_field="value1", + ) + extended2 = ExtendedIdentifier( + class_name="TestClass", + class_module="test.module", + identifier_type="class", + class_description="Description", + metadata_field="value2", + ) + + # Hash should be the same since metadata_field is excluded from hash + assert extended1.hash == extended2.hash + + # But both values should be in storage + storage1 = extended1.to_dict() + storage2 = extended2.to_dict() + assert storage1["metadata_field"] == "value1" + assert storage2["metadata_field"] == "value2" + + def test_storage_exclusion_works(self): + """Test that a field excluded from storage is also excluded from hash via expansion.""" + + @dataclass(frozen=True) + class ExtendedIdentifier(Identifier): + # Only need STORAGE - HASH is implied via expansion + transient_field: str = field(default="", metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}}) + + extended1 = ExtendedIdentifier( + class_name="TestClass", + class_module="test.module", + identifier_type="class", + class_description="Description", + transient_field="value1", + ) + extended2 = ExtendedIdentifier( + class_name="TestClass", + class_module="test.module", + identifier_type="class", + class_description="Description", + transient_field="value2", + ) + + # Hash should be the same since transient_field is excluded from hash + assert extended1.hash == extended2.hash + + # Field should not be in storage + storage1 = extended1.to_dict() + assert "transient_field" not in storage1 diff --git a/tests/unit/identifiers/test_scorer_identifier.py b/tests/unit/identifiers/test_scorer_identifier.py index 822b678f0..033c24bee 100644 --- a/tests/unit/identifiers/test_scorer_identifier.py +++ b/tests/unit/identifiers/test_scorer_identifier.py @@ -162,7 +162,8 @@ def test_to_dict_basic(self): assert result["class_name"] == "TestScorer" assert result["class_module"] == "pyrit.score.test_scorer" assert result["hash"] == identifier.hash - assert result["unique_name"] == identifier.unique_name + # unique_name is excluded from storage (has _ExcludeFrom.STORAGE metadata) + assert "unique_name" not in result # class_description and identifier_type should be excluded assert "class_description" not in result assert "identifier_type" not in result