diff --git a/src/lean_spec/subspecs/networking/enr/eth2.py b/src/lean_spec/subspecs/networking/enr/eth2.py index 231e344c..7dc5437f 100644 --- a/src/lean_spec/subspecs/networking/enr/eth2.py +++ b/src/lean_spec/subspecs/networking/enr/eth2.py @@ -26,7 +26,7 @@ from typing import ClassVar -from lean_spec.subspecs.networking.types import ForkDigest +from lean_spec.subspecs.networking.types import ForkDigest, Version from lean_spec.types import StrictBaseModel, Uint64 from lean_spec.types.bitfields import BaseBitvector from lean_spec.types.boolean import Boolean @@ -45,18 +45,18 @@ class Eth2Data(StrictBaseModel): fork_digest: ForkDigest """Current active fork identifier (4 bytes).""" - next_fork_version: ForkDigest - """Fork version of next scheduled fork. Equals current if none scheduled.""" + next_fork_version: Version + """Fork version of next scheduled fork. Equals current version if none scheduled.""" next_fork_epoch: Uint64 """Epoch when next fork activates. FAR_FUTURE_EPOCH if none scheduled.""" @classmethod - def no_scheduled_fork(cls, current_digest: ForkDigest) -> "Eth2Data": - """Create Eth2Data with no scheduled fork.""" + def no_scheduled_fork(cls, current_digest: ForkDigest, current_version: Version) -> "Eth2Data": + """Create Eth2Data indicating no scheduled fork.""" return cls( fork_digest=current_digest, - next_fork_version=current_digest, + next_fork_version=current_version, next_fork_epoch=FAR_FUTURE_EPOCH, ) @@ -74,32 +74,32 @@ class AttestationSubnets(BaseBitvector): @classmethod def none(cls) -> "AttestationSubnets": """No subscriptions.""" - return cls(data=[Boolean(False)] * 64) + return cls(data=[Boolean(False)] * cls.LENGTH) @classmethod def all(cls) -> "AttestationSubnets": """Subscribe to all 64 subnets.""" - return cls(data=[Boolean(True)] * 64) + return cls(data=[Boolean(True)] * cls.LENGTH) @classmethod def from_subnet_ids(cls, subnet_ids: list[int]) -> "AttestationSubnets": """Subscribe to specific subnets.""" - bits = [Boolean(False)] * 64 + bits = [Boolean(False)] * cls.LENGTH for sid in subnet_ids: - if not 0 <= sid < 64: + if not 0 <= sid < cls.LENGTH: raise ValueError(f"Subnet ID must be 0-63, got {sid}") bits[sid] = Boolean(True) return cls(data=bits) def is_subscribed(self, subnet_id: int) -> bool: """Check if subscribed to a subnet.""" - if not 0 <= subnet_id < 64: + if not 0 <= subnet_id < self.LENGTH: raise ValueError(f"Subnet ID must be 0-63, got {subnet_id}") return bool(self.data[subnet_id]) def subscribed_subnets(self) -> list[int]: """List of subscribed subnet IDs.""" - return [i for i in range(64) if self.data[i]] + return [i for i in range(self.LENGTH) if self.data[i]] def subscription_count(self) -> int: """Number of subscribed subnets.""" @@ -119,15 +119,33 @@ class SyncCommitteeSubnets(BaseBitvector): @classmethod def none(cls) -> "SyncCommitteeSubnets": """No subscriptions.""" - return cls(data=[Boolean(False)] * 4) + return cls(data=[Boolean(False)] * cls.LENGTH) @classmethod def all(cls) -> "SyncCommitteeSubnets": """Subscribe to all 4 subnets.""" - return cls(data=[Boolean(True)] * 4) + return cls(data=[Boolean(True)] * cls.LENGTH) + + @classmethod + def from_subnet_ids(cls, subnet_ids: list[int]) -> "SyncCommitteeSubnets": + """Subscribe to specific sync subnets.""" + bits = [Boolean(False)] * cls.LENGTH + for sid in subnet_ids: + if not 0 <= sid < cls.LENGTH: + raise ValueError(f"Sync subnet ID must be 0-3, got {sid}") + bits[sid] = Boolean(True) + return cls(data=bits) def is_subscribed(self, subnet_id: int) -> bool: """Check if subscribed to a sync subnet.""" - if not 0 <= subnet_id < 4: + if not 0 <= subnet_id < self.LENGTH: raise ValueError(f"Sync subnet ID must be 0-3, got {subnet_id}") return bool(self.data[subnet_id]) + + def subscribed_subnets(self) -> list[int]: + """List of subscribed sync subnet IDs.""" + return [i for i in range(self.LENGTH) if self.data[i]] + + def subscription_count(self) -> int: + """Number of subscribed sync subnets.""" + return sum(1 for b in self.data if b) diff --git a/src/lean_spec/subspecs/networking/types.py b/src/lean_spec/subspecs/networking/types.py index 98c59334..373919aa 100644 --- a/src/lean_spec/subspecs/networking/types.py +++ b/src/lean_spec/subspecs/networking/types.py @@ -20,6 +20,9 @@ ForkDigest = Bytes4 """4-byte fork identifier ensuring network isolation between forks.""" +Version = Bytes4 +"""4-byte fork version number (e.g., 0x01000000 for Phase0).""" + SeqNumber = Uint64 """Sequence number used in ENR records, metadata, and ping messages.""" diff --git a/tests/lean_spec/subspecs/networking/enr/test_eth2.py b/tests/lean_spec/subspecs/networking/enr/test_eth2.py index 8f200fc1..b04737b3 100644 --- a/tests/lean_spec/subspecs/networking/enr/test_eth2.py +++ b/tests/lean_spec/subspecs/networking/enr/test_eth2.py @@ -4,7 +4,11 @@ from pydantic import ValidationError from lean_spec.subspecs.networking.enr import Eth2Data -from lean_spec.subspecs.networking.enr.eth2 import AttestationSubnets, SyncCommitteeSubnets +from lean_spec.subspecs.networking.enr.eth2 import ( + FAR_FUTURE_EPOCH, + AttestationSubnets, + SyncCommitteeSubnets, +) from lean_spec.types import Uint64 from lean_spec.types.byte_arrays import Bytes4 @@ -25,11 +29,12 @@ def test_create_eth2_data(self) -> None: def test_no_scheduled_fork_factory(self) -> None: """no_scheduled_fork factory creates correct data.""" digest = Bytes4(b"\xab\xcd\xef\x01") - data = Eth2Data.no_scheduled_fork(digest) + version = Bytes4(b"\x01\x00\x00\x00") + data = Eth2Data.no_scheduled_fork(digest, version) assert data.fork_digest == digest - assert data.next_fork_version == digest - assert data.next_fork_epoch == Uint64(2**64 - 1) + assert data.next_fork_version == version + assert data.next_fork_epoch == FAR_FUTURE_EPOCH def test_eth2_data_immutable(self) -> None: """Eth2Data is immutable (frozen).""" @@ -41,6 +46,10 @@ def test_eth2_data_immutable(self) -> None: with pytest.raises(ValidationError): data.fork_digest = Bytes4(b"\x00\x00\x00\x00") + def test_far_future_epoch_value(self) -> None: + """FAR_FUTURE_EPOCH is max uint64.""" + assert FAR_FUTURE_EPOCH == Uint64(2**64 - 1) + class TestAttestationSubnets: """Tests for AttestationSubnets bitvector.""" @@ -93,6 +102,39 @@ def test_invalid_subnet_id_in_is_subscribed(self) -> None: with pytest.raises(ValueError): subnets.is_subscribed(-1) + def test_from_subnet_ids_empty_list(self) -> None: + """from_subnet_ids with empty list creates no subscriptions.""" + subnets = AttestationSubnets.from_subnet_ids([]) + assert subnets.subscription_count() == 0 + assert subnets.subscribed_subnets() == [] + + def test_from_subnet_ids_with_duplicates(self) -> None: + """from_subnet_ids handles duplicates correctly.""" + subnets = AttestationSubnets.from_subnet_ids([5, 5, 5, 10]) + assert subnets.subscription_count() == 2 + assert subnets.subscribed_subnets() == [5, 10] + + def test_encode_bytes_empty(self) -> None: + """Empty subscriptions serialize to 8 zero bytes.""" + subnets = AttestationSubnets.none() + assert subnets.encode_bytes() == b"\x00" * 8 + + def test_encode_bytes_all(self) -> None: + """All subscriptions serialize to 8 0xff bytes.""" + subnets = AttestationSubnets.all() + assert subnets.encode_bytes() == b"\xff" * 8 + + def test_decode_bytes_roundtrip(self) -> None: + """Encode then decode produces equivalent result.""" + original = AttestationSubnets.from_subnet_ids([0, 5, 63]) + encoded = original.encode_bytes() + decoded = AttestationSubnets.decode_bytes(encoded) + assert decoded.subscribed_subnets() == original.subscribed_subnets() + + def test_length_constant(self) -> None: + """LENGTH constant is 64.""" + assert AttestationSubnets.LENGTH == 64 + class TestSyncCommitteeSubnets: """Tests for SyncCommitteeSubnets bitvector.""" @@ -128,3 +170,61 @@ def test_is_subscribed_raises_for_negative_id(self) -> None: subnets = SyncCommitteeSubnets.none() with pytest.raises(ValueError, match="must be 0-3"): subnets.is_subscribed(-1) + + def test_from_subnet_ids_specific(self) -> None: + """from_subnet_ids() creates specific subscriptions.""" + subnets = SyncCommitteeSubnets.from_subnet_ids([0, 2]) + assert subnets.is_subscribed(0) + assert not subnets.is_subscribed(1) + assert subnets.is_subscribed(2) + assert not subnets.is_subscribed(3) + + def test_from_subnet_ids_empty_list(self) -> None: + """from_subnet_ids with empty list creates no subscriptions.""" + subnets = SyncCommitteeSubnets.from_subnet_ids([]) + assert subnets.subscription_count() == 0 + + def test_from_subnet_ids_with_duplicates(self) -> None: + """from_subnet_ids handles duplicates correctly.""" + subnets = SyncCommitteeSubnets.from_subnet_ids([1, 1, 1, 3]) + assert subnets.subscription_count() == 2 + assert subnets.subscribed_subnets() == [1, 3] + + def test_from_subnet_ids_invalid(self) -> None: + """from_subnet_ids() raises for invalid subnet IDs.""" + with pytest.raises(ValueError, match="must be 0-3"): + SyncCommitteeSubnets.from_subnet_ids([4]) + + with pytest.raises(ValueError, match="must be 0-3"): + SyncCommitteeSubnets.from_subnet_ids([-1]) + + def test_subscribed_subnets(self) -> None: + """subscribed_subnets() returns correct list.""" + subnets = SyncCommitteeSubnets.from_subnet_ids([1, 3]) + assert subnets.subscribed_subnets() == [1, 3] + + def test_subscription_count(self) -> None: + """subscription_count() returns correct count.""" + subnets = SyncCommitteeSubnets.from_subnet_ids([0, 2, 3]) + assert subnets.subscription_count() == 3 + + def test_encode_bytes_empty(self) -> None: + """Empty subscriptions serialize to 1 zero byte.""" + subnets = SyncCommitteeSubnets.none() + assert subnets.encode_bytes() == b"\x00" + + def test_encode_bytes_all(self) -> None: + """All subscriptions serialize to 0x0f (lower 4 bits set).""" + subnets = SyncCommitteeSubnets.all() + assert subnets.encode_bytes() == b"\x0f" + + def test_decode_bytes_roundtrip(self) -> None: + """Encode then decode produces equivalent result.""" + original = SyncCommitteeSubnets.from_subnet_ids([0, 2]) + encoded = original.encode_bytes() + decoded = SyncCommitteeSubnets.decode_bytes(encoded) + assert decoded.subscribed_subnets() == original.subscribed_subnets() + + def test_length_constant(self) -> None: + """LENGTH constant is 4.""" + assert SyncCommitteeSubnets.LENGTH == 4