diff --git a/src/lean_spec/__main__.py b/src/lean_spec/__main__.py index c188abcc..cea97679 100644 --- a/src/lean_spec/__main__.py +++ b/src/lean_spec/__main__.py @@ -6,11 +6,14 @@ Usage:: python -m lean_spec --genesis genesis.json --bootnode /ip4/127.0.0.1/tcp/9000 + python -m lean_spec --genesis genesis.json --bootnode enr:-IS4QHCYrYZbAKW... + python -m lean_spec --genesis genesis.json --checkpoint-sync-url http://localhost:5052 Options: - --genesis Path to genesis JSON file (required) - --bootnode Multiaddr of bootnode to connect to (can be repeated) - --listen Address to listen on (default: /ip4/0.0.0.0/tcp/9000) + --genesis Path to genesis JSON file (required) + --bootnode Bootnode address (multiaddr or ENR string, can be repeated) + --listen Address to listen on (default: /ip4/0.0.0.0/tcp/9000) + --checkpoint-sync-url URL to fetch finalized checkpoint state for fast sync """ from __future__ import annotations @@ -20,17 +23,268 @@ import logging from pathlib import Path -from lean_spec.subspecs.containers import Checkpoint +from lean_spec.subspecs.containers import Block, BlockBody, Checkpoint, State +from lean_spec.subspecs.containers.block.types import AggregatedAttestations from lean_spec.subspecs.containers.slot import Slot +from lean_spec.subspecs.forkchoice import Store from lean_spec.subspecs.genesis import GenesisConfig from lean_spec.subspecs.networking.client import LiveNetworkEventSource from lean_spec.subspecs.networking.reqresp.message import Status from lean_spec.subspecs.node import Node, NodeConfig +from lean_spec.subspecs.ssz.hash import hash_tree_root from lean_spec.types import Bytes32 logger = logging.getLogger(__name__) +def is_enr_string(bootnode: str) -> bool: + """ + Check if bootnode string is an ENR (vs multiaddr). + + Uses prefix detection rather than attempting full parsing. + This is both faster and avoids import overhead for simple checks. + + Per EIP-778, all ENR strings begin with "enr:" followed by base64url content. 
+ """ + return bootnode.startswith("enr:") + + +def resolve_bootnode(bootnode: str) -> str: + """ + Resolve a bootnode string to a multiaddr. + + Supports both ENR and multiaddr formats for interoperability. + Different tools emit different formats: + + - Lighthouse, Prysm: Often provide ENR strings + - libp2p tools: Usually provide multiaddrs directly + + Args: + bootnode: Either an ENR string (enr:-IS4Q...) or multiaddr (/ip4/.../tcp/...). + + Returns: + Multiaddr string suitable for dialing. + + Raises: + ValueError: If ENR is malformed or has no TCP connection info. + """ + if is_enr_string(bootnode): + from lean_spec.subspecs.networking.enr import ENR + + enr = ENR.from_string(bootnode) + + # ENR.multiaddr() returns None when the record lacks IP or TCP port. + # + # This happens with discovery-only ENRs that only contain UDP info. + # We require TCP for libp2p connections. + multiaddr = enr.multiaddr() + if multiaddr is None: + raise ValueError(f"ENR has no TCP connection info: {enr}") + return multiaddr + + # Already a multiaddr string. Pass through without validation. + # + # Validation happens when dialing; early validation here would + # duplicate logic and reduce flexibility for multiaddr extensions. + return bootnode + + +def create_anchor_block(state: State) -> Block: + """ + Create an anchor block from a checkpoint state. + + The forkchoice store requires a block to establish the starting point. + We reconstruct this "anchor block" from the header embedded in the state. + + The body content does not matter for fork choice initialization. + Only header fields (slot, parent, state root) establish the anchor. + + Args: + state: The checkpoint state containing the latest block header. + + Returns: + A Block suitable for initializing the forkchoice store. + """ + header = state.latest_block_header + + # The state root in the header may be zero. + # + # Why? Block processing stores the header BEFORE computing post-state root. 
+ # This prevents circular dependency: state root depends on header, header + # would depend on state root. The spec breaks this cycle by storing zero + # initially, then filling it in when the next slot processes. + # + # For checkpoint sync, we may receive state at exactly the block's slot. + # In this case, the state root was never filled in. We compute it now. + state_root = header.state_root + if state_root == Bytes32.zero(): + state_root = hash_tree_root(state) + + # Build a minimal body. + # + # Fork choice only cares about the block's identity (its hash) and + # lineage (parent_root). The body content is irrelevant for anchoring. + # We use an empty body because we lack the original block data. + body = BlockBody(attestations=AggregatedAttestations(data=[])) + + return Block( + slot=header.slot, + proposer_index=header.proposer_index, + parent_root=header.parent_root, + state_root=state_root, + body=body, + ) + + +def _init_from_genesis( + genesis: GenesisConfig, + event_source: LiveNetworkEventSource, +) -> Node: + """ + Initialize a node from genesis configuration. + + Args: + genesis: Genesis configuration with time and validators. + event_source: Network transport for the node. + + Returns: + A fully initialized Node starting from genesis. + """ + # Set initial status for handshakes. + # + # At genesis, our finalized and head are both the genesis block (unknown root). + genesis_status = Status( + finalized=Checkpoint(root=Bytes32.zero(), slot=Slot(0)), + head=Checkpoint(root=Bytes32.zero(), slot=Slot(0)), + ) + event_source.set_status(genesis_status) + + # Create node configuration. + config = NodeConfig( + genesis_time=genesis.genesis_time, + validators=genesis.to_validators(), + event_source=event_source, + network=event_source.reqresp_client, + ) + + # Create and return the node. 
+ return Node.from_genesis(config) + + +async def _init_from_checkpoint( + checkpoint_sync_url: str, + genesis: GenesisConfig, + event_source: LiveNetworkEventSource, +) -> Node | None: + """ + Initialize a node from a checkpoint state fetched from a remote node. + + Checkpoint sync trades trustlessness for speed. The node trusts the + checkpoint source to provide a valid finalized state. This is acceptable + because: + + - The state is finalized (2/3 of validators attested to it) + - Users explicitly opt in via the CLI flag + - The alternative (syncing from genesis) takes hours or days + + Processing steps: + + 1. Fetch finalized state from checkpoint URL + 2. Verify structural validity + 3. Validate genesis time matches + 4. Create anchor block + 5. Initialize forkchoice store + 6. Return configured Node + + Args: + checkpoint_sync_url: URL of the node to fetch checkpoint state from. + genesis: Local genesis configuration for validation. + event_source: Network transport for the node. + + Returns: + A fully initialized Node if successful, None if checkpoint sync failed. + """ + from lean_spec.subspecs.api.client import ( + CheckpointSyncError, + fetch_finalized_state, + verify_checkpoint_state, + ) + + try: + logger.info("Fetching checkpoint state from %s", checkpoint_sync_url) + state = await fetch_finalized_state(checkpoint_sync_url, State) + + # Structural validation catches corrupted or malformed states. + # + # This is defense in depth. We trust the source, but still verify + # basic invariants before using the state. + if not await verify_checkpoint_state(state): + logger.error("Checkpoint state verification failed") + return None + + # Genesis time MUST match. + # + # This is our only protection against syncing to a different chain. + # If genesis times differ, the checkpoint belongs to another network. + # We reject rather than risk corrupting our view of the chain. + # + # We do NOT fall back to genesis sync on failure. 
That would silently + # mask configuration errors and leave operators unaware their node + # started from scratch instead of the checkpoint. + if state.config.genesis_time != genesis.genesis_time: + logger.error( + "Genesis time mismatch: checkpoint=%d, local=%d", + state.config.genesis_time, + genesis.genesis_time, + ) + return None + + # Create anchor block from checkpoint state. + anchor_block = create_anchor_block(state) + + # Initialize forkchoice store from checkpoint. + # + # The store treats this as the new "genesis" for fork choice purposes. + # All blocks before the checkpoint are effectively pruned. + store = Store.get_forkchoice_store(state, anchor_block) + logger.info( + "Initialized from checkpoint at slot %d (finalized=%s)", + state.slot, + store.latest_finalized.root.hex()[:16], + ) + + # Set initial status for handshakes based on checkpoint. + checkpoint_status = Status( + finalized=store.latest_finalized, + head=Checkpoint(root=store.head, slot=store.blocks[store.head].slot), + ) + event_source.set_status(checkpoint_status) + + # Use validators from checkpoint state, not genesis. + # + # The validator set evolves over time. Deposits add validators, + # exits remove them. The checkpoint state reflects the current set. + config = NodeConfig( + genesis_time=genesis.genesis_time, + validators=state.validators, + event_source=event_source, + network=event_source.reqresp_client, + ) + + # Create node and inject checkpoint store. + # + # TODO: Add a dedicated factory method for cleaner API. 
+ node = Node.from_genesis(config) + node.store = store + node.sync_service.store = store + + return node + + except CheckpointSyncError as e: + logger.error("Checkpoint sync failed: %s", e) + return None + + def setup_logging(verbose: bool = False) -> None: """Configure logging for the node.""" level = logging.DEBUG if verbose else logging.INFO @@ -45,6 +299,7 @@ async def run_node( genesis_path: Path, bootnodes: list[str], listen_addr: str, + checkpoint_sync_url: str | None = None, ) -> None: """ Run the lean consensus node. @@ -53,8 +308,8 @@ async def run_node( genesis_path: Path to genesis JSON file. bootnodes: List of bootnode multiaddrs to connect to. listen_addr: Address to listen on. + checkpoint_sync_url: Optional URL to fetch checkpoint state for fast sync. """ - # Load genesis configuration. logger.info("Loading genesis from %s", genesis_path) genesis = GenesisConfig.from_json_file(genesis_path) logger.info( @@ -63,49 +318,63 @@ async def run_node( len(genesis.genesis_validators), ) - # Create network transport. event_source = LiveNetworkEventSource.create() - # Create initial status for handshakes. + # Two initialization paths: checkpoint sync or genesis sync. # - # At genesis, our finalized and head are both the genesis block. - genesis_status = Status( - finalized=Checkpoint(root=Bytes32.zero(), slot=Slot(0)), - head=Checkpoint(root=Bytes32.zero(), slot=Slot(0)), - ) - event_source.set_status(genesis_status) - - # Create node configuration. 
- config = NodeConfig( - genesis_time=genesis.genesis_time, - validators=genesis.to_validators(), - event_source=event_source, - network=event_source.reqresp_client, - ) + # Checkpoint sync (preferred for mainnet/testnets): + # - Downloads finalized state from trusted node + # - Skips weeks/months of historical block processing + # - Ready to participate in consensus within minutes + # + # Genesis sync (required for new networks): + # - Starts from block 0 with initial validator set + # - Must process every block to reach current head + # - Only practical for new or small networks + node: Node | None + if checkpoint_sync_url is not None: + node = await _init_from_checkpoint( + checkpoint_sync_url=checkpoint_sync_url, + genesis=genesis, + event_source=event_source, + ) + if node is None: + # Checkpoint sync failed. Exit rather than falling back. + # + # Silent fallback to genesis would surprise operators. + # They explicitly requested checkpoint sync for a reason. + return + else: + node = _init_from_genesis(genesis=genesis, event_source=event_source) - # Create the node. - node = Node.from_genesis(config) logger.info("Node initialized, peer_id=%s", event_source.connection_manager.peer_id) - # Update status with actual genesis block root. - # - # At genesis, the head and finalized are both the genesis block. - # The store.head is initialized to the genesis block root. - genesis_root = node.store.head + # Update status with actual head and finalized checkpoints. updated_status = Status( - finalized=Checkpoint(root=genesis_root, slot=Slot(0)), - head=Checkpoint(root=genesis_root, slot=Slot(0)), + finalized=node.store.latest_finalized, + head=Checkpoint(root=node.store.head, slot=node.store.blocks[node.store.head].slot), ) event_source.set_status(updated_status) # Connect to bootnodes. + # + # Best-effort connection: failures don't abort the loop. + # The node can still function if at least one bootnode connects. 
for bootnode in bootnodes: - logger.info("Connecting to bootnode %s", bootnode) - peer_id = await event_source.dial(bootnode) - if peer_id: - logger.info("Connected to bootnode, peer_id=%s", peer_id) - else: - logger.warning("Failed to connect to bootnode %s", bootnode) + try: + multiaddr = resolve_bootnode(bootnode) + logger.info("Connecting to bootnode %s", multiaddr) + peer_id = await event_source.dial(multiaddr) + if peer_id: + logger.info("Connected to bootnode, peer_id=%s", peer_id) + else: + logger.warning("Failed to connect to bootnode %s", multiaddr) + except ValueError as e: + # Truncate bootnode string in error logs. + # + # ENR strings can exceed 200 characters, making logs unreadable. + # First 40 chars include the "enr:" prefix and enough to identify. + logger.warning("Invalid bootnode %s: %s", bootnode[:40], e) # Start listening (in background). if listen_addr: @@ -136,13 +405,19 @@ def main() -> None: action="append", default=[], dest="bootnodes", - help="Bootnode multiaddr (can be repeated)", + help="Bootnode address (multiaddr or ENR string, can be repeated)", ) parser.add_argument( "--listen", default="/ip4/0.0.0.0/tcp/9000", help="Address to listen on (default: /ip4/0.0.0.0/tcp/9000)", ) + parser.add_argument( + "--checkpoint-sync-url", + type=str, + default=None, + help="URL to fetch finalized checkpoint state for fast sync (e.g., http://localhost:5052)", + ) parser.add_argument( "-v", "--verbose", @@ -155,7 +430,14 @@ def main() -> None: setup_logging(args.verbose) try: - asyncio.run(run_node(args.genesis, args.bootnodes, args.listen)) + asyncio.run( + run_node( + args.genesis, + args.bootnodes, + args.listen, + args.checkpoint_sync_url, + ) + ) except KeyboardInterrupt: logger.info("Shutting down...") diff --git a/src/lean_spec/subspecs/api/client.py b/src/lean_spec/subspecs/api/client.py index fdd84b50..4511d8a0 100644 --- a/src/lean_spec/subspecs/api/client.py +++ b/src/lean_spec/subspecs/api/client.py @@ -1,8 +1,19 @@ """ Checkpoint sync 
client for downloading finalized state from another node. -This client is used for fast synchronization - instead of syncing from genesis, -a node can download the finalized state from a trusted peer and start from there. +Checkpoint sync enables fast startup by skipping historical block processing. +Instead of replaying every block from genesis, a node downloads a recent +finalized state and starts from there. + +Trust model: + +- The operator trusts the checkpoint source to provide valid finalized state +- This trust is acceptable because finalized state has 2/3 validator support +- The alternative (genesis sync) may take hours or days on mainnet + +The trade-off is trustlessness for speed. Most operators accept this because +they already trust their checkpoint source (often their own infrastructure +or a well-known provider). """ from __future__ import annotations @@ -21,36 +32,49 @@ logger = logging.getLogger(__name__) -# Constants DEFAULT_TIMEOUT = 60.0 +"""HTTP request timeout in seconds. Large states may take time to transfer.""" + FINALIZED_STATE_ENDPOINT = "/lean/states/finalized" +"""API endpoint for fetching finalized state. Follows Beacon API conventions.""" class CheckpointSyncError(Exception): - """Error during checkpoint sync.""" + """ + Error during checkpoint sync. + + Raised when the checkpoint state cannot be fetched or is invalid. + Callers should handle this by aborting startup (not falling back). + """ async def fetch_finalized_state(url: str, state_class: type[Any]) -> "State": """ Fetch finalized state from a node via checkpoint sync. - Downloads the finalized state as SSZ binary and deserializes it. + Downloads the state as SSZ binary and deserializes it. SSZ format is + preferred over JSON because state objects are large (tens of MB) and + SSZ is more compact and faster to parse. 
Args: - url: Base URL of the node API (e.g., "http://localhost:5052") - state_class: The State class to deserialize into + url: Base URL of the node API (e.g., "http://localhost:5052"). + state_class: The State class to deserialize into. Returns: - The finalized State object + The finalized State object. Raises: - CheckpointSyncError: If the request fails or state is invalid + CheckpointSyncError: If the request fails or state is invalid. """ base_url = url.rstrip("/") full_url = f"{base_url}{FINALIZED_STATE_ENDPOINT}" logger.info(f"Fetching finalized state from {full_url}") + # Request SSZ binary format. + # + # The Accept header tells the server we want raw bytes, not JSON. + # This is faster to transfer and parse than JSON encoding. headers = {"Accept": "application/octet-stream"} try: @@ -61,6 +85,10 @@ async def fetch_finalized_state(url: str, state_class: type[Any]) -> "State": ssz_data = response.content logger.info(f"Downloaded {len(ssz_data)} bytes of SSZ state data") + # Deserialize from SSZ bytes. + # + # This validates the byte stream matches the expected schema. + # Malformed data will raise an exception here. state = state_class.decode_bytes(ssz_data) logger.info(f"Deserialized state at slot {state.slot}") @@ -80,26 +108,40 @@ async def fetch_finalized_state(url: str, state_class: type[Any]) -> "State": async def verify_checkpoint_state(state: "State") -> bool: """ - Verify that a checkpoint state is valid. + Verify that a checkpoint state is structurally valid. + + This is defense-in-depth validation. We trust the checkpoint source, + but still verify basic invariants before using the state. These checks + catch corrupted downloads or misconfigured servers. + + The checks are intentionally minimal: + + - Slot is non-negative (sanity check) + - Validators exist (empty state is useless) + - Validator count within limits (prevents DoS) - Performs basic validation checks on the downloaded state. + We do NOT verify cryptographic proofs here. 
That would require + the full block history, defeating the purpose of checkpoint sync. Args: - state: The state to verify + state: The state to verify. Returns: - True if valid, False otherwise + True if valid, False otherwise. """ try: + # Sanity check: slot must be non-negative. if state.slot < Slot(0): logger.error("Invalid state: negative slot") return False + # A state with no validators cannot produce blocks. validator_count = len(state.validators) if validator_count == 0: logger.error("Invalid state: no validators") return False + # Guard against oversized states that could exhaust memory. if validator_count > int(DEVNET_CONFIG.validator_registry_limit): logger.error( f"Invalid state: validator count {validator_count} exceeds " @@ -107,6 +149,10 @@ async def verify_checkpoint_state(state: "State") -> bool: ) return False + # Compute state root to verify SSZ deserialization worked correctly. + # + # If the data was corrupted, hashing will likely fail or produce + # an unexpected result. We log the root for debugging. state_root = hash_tree_root(state) root_preview = state_root.hex()[:16] logger.info(f"Checkpoint state verified: slot={state.slot}, root={root_preview}...") diff --git a/src/lean_spec/subspecs/networking/client/event_source.py b/src/lean_spec/subspecs/networking/client/event_source.py index 84d35a23..82f108ee 100644 --- a/src/lean_spec/subspecs/networking/client/event_source.py +++ b/src/lean_spec/subspecs/networking/client/event_source.py @@ -5,12 +5,97 @@ network connections. It bridges the gap between the low-level transport layer (ConnectionManager + yamux) and the high-level sync service. -Event Flow + +WHY THIS MODULE EXISTS +---------------------- +The sync service operates at a high level of abstraction. It thinks in +terms of "block arrived" or "peer connected" events. The transport layer +operates at the byte level: TCP streams, encrypted frames, multiplexed +channels. This module translates between these worlds. 
+ + +EVENT FLOW ---------- -1. ConnectionManager establishes connections (Noise + yamux) -2. LiveNetworkEventSource monitors connections for activity -3. Incoming messages are parsed and converted to NetworkEvent objects -4. NetworkService consumes events via async iteration +Messages flow through the system in stages: + +1. ConnectionManager establishes connections (Noise + yamux). +2. LiveNetworkEventSource monitors connections for activity. +3. Incoming messages are parsed and converted to NetworkEvent objects. +4. NetworkService consumes events via async iteration. + + +GOSSIP MESSAGE FLOW +------------------- +When a peer publishes a block or attestation, it arrives as follows: + +1. Peer opens a yamux stream with protocol ID "/meshsub/1.1.0". +2. Peer sends: [topic_length][topic][data_length][compressed_data]. +3. We parse the topic to determine message type (block vs attestation). +4. We decompress the Snappy-framed payload. +5. We decode the SSZ bytes into a typed object. +6. We emit a GossipBlockEvent or GossipAttestationEvent. + + +GOSSIP MESSAGE FORMAT +--------------------- +Incoming gossip messages arrive on yamux streams with the gossipsub protocol ID. +The message format is: + ++------------------+---------------------------------------------+ +| Field | Description | ++==================+=============================================+ +| topic_length | Varint: byte length of the topic string | ++------------------+---------------------------------------------+ +| topic | UTF-8 string identifying message type | ++------------------+---------------------------------------------+ +| data_length | Varint: byte length of compressed data | ++------------------+---------------------------------------------+ +| data | Snappy-framed SSZ-encoded message | ++------------------+---------------------------------------------+ + +Varints use LEB128 encoding (1-10 bytes depending on value). +Most lengths fit in 1-2 bytes since messages are typically under 16KB. 
+ + +MESSAGE DEDUPLICATION +--------------------- +Gossipsub uses message IDs to prevent duplicate delivery. The Ethereum +consensus spec defines message ID as: + + message_id = SHA256(MESSAGE_DOMAIN + topic_length + topic + data)[:20] + +MESSAGE_DOMAIN is 0x00 for invalid Snappy, 0x01 for valid Snappy. This +domain separation ensures a message cannot be "replayed" by flipping +between compressed and raw forms. + + +WHY SSZ AND SNAPPY? +------------------- +SSZ (Simple Serialize) is Ethereum's canonical serialization format: + +- Deterministic: Same object always produces same bytes. +- Merkleizable: Supports efficient proofs of inclusion. +- Fixed overhead: Known sizes enable buffer pre-allocation. + +Snappy compression reduces bandwidth by 50-70% for typical blocks. +The framing format adds CRC32C checksums for corruption detection. + + +GOSSIPSUB v1.1 REQUIREMENTS +--------------------------- +The Ethereum consensus spec requires gossipsub v1.1 (protocol "/meshsub/1.1.0"). +Key v1.1 features used: + +- Peer scoring: Misbehaving peers get lower scores. +- Extended validators: Message validation before forwarding. +- Flood publishing: High-priority messages bypass mesh constraints. 
+ + +References: + - Ethereum P2P spec: https://github.com/ethereum/consensus-specs/blob/dev/specs/phase0/p2p-interface.md + - Gossipsub v1.1: https://github.com/libp2p/specs/blob/master/pubsub/gossipsub/gossipsub-v1.1.md + - SSZ spec: https://github.com/ethereum/consensus-specs/blob/dev/ssz/simple-serialize.md + - Snappy framing: https://github.com/google/snappy/blob/master/framing_format.txt """ from __future__ import annotations @@ -19,10 +104,17 @@ import logging from dataclasses import dataclass, field +from lean_spec.snappy import SnappyDecompressionError, frame_decompress from lean_spec.subspecs.containers import SignedBlockWithAttestation from lean_spec.subspecs.containers.attestation import SignedAttestation from lean_spec.subspecs.networking.config import GOSSIPSUB_DEFAULT_PROTOCOL_ID from lean_spec.subspecs.networking.gossipsub.topic import GossipTopic, TopicKind +from lean_spec.subspecs.networking.reqresp.handler import ( + REQRESP_PROTOCOL_IDS, + BlockLookup, + DefaultRequestHandler, + ReqRespServer, +) from lean_spec.subspecs.networking.reqresp.message import Status from lean_spec.subspecs.networking.service.events import ( GossipAttestationEvent, @@ -37,13 +129,294 @@ ConnectionManager, YamuxConnection, ) +from lean_spec.subspecs.networking.transport.connection.types import Stream +from lean_spec.subspecs.networking.varint import VarintError +from lean_spec.subspecs.networking.varint import decode as decode_varint from lean_spec.subspecs.networking.varint import encode as encode_varint +from lean_spec.types.exceptions import SSZSerializationError from .reqresp_client import ReqRespClient logger = logging.getLogger(__name__) +class GossipMessageError(Exception): + """Raised when a gossip message cannot be processed. + + This error wraps underlying failures (topic parsing, decompression, + SSZ decoding) into a single type for cleaner error handling at the + stream processing level. 
+ """ + + +@dataclass(slots=True) +class GossipHandler: + """ + Handles incoming gossip messages from peers. + + Parses gossip message format, decompresses Snappy, decodes SSZ, and + returns the appropriate decoded object. + + Supported topic kinds: + + - Block: Decodes to SignedBlockWithAttestation + - Attestation: Decodes to SignedAttestation + + + WHY TOPIC VALIDATION? + --------------------- + Topics contain: + + - Fork digest: 4-byte identifier derived from genesis + fork version. + - Message type: "block" or "attestation". + - Encoding: Always "ssz_snappy" for Ethereum. + + Validating the topic prevents: + + - Routing attacks: Reject messages for different forks. + - Type confusion: Ensure we decode with the correct schema. + - Protocol violations: Reject malformed topic strings. + + + WHY SNAPPY? + ----------- + Snappy reduces bandwidth by 50-70% for typical consensus messages. + Beacon blocks contain many signatures and hashes which compress well. + The framing format adds CRC32C checksums for corruption detection. + + + WHY SSZ? + -------- + SSZ (Simple Serialize) is Ethereum's canonical format because: + + - Deterministic: Same object always produces same bytes. + - Merkleizable: Efficient proofs of inclusion. + - Schema-driven: Type information comes from context, not wire format. + + The topic tells us the schema. The SSZ bytes are just raw data. + """ + + fork_digest: str + """Expected fork digest for topic validation. + + Intended for rejecting messages from other forks. Note: the methods of + this class parse the topic but do not compare it against this value. + """ + + def decode_message( + self, + topic_str: str, + compressed_data: bytes, + ) -> SignedBlockWithAttestation | SignedAttestation: + """ + Decode a gossip message from topic and compressed data. + + Processing proceeds in order: + + 1. Parse topic to determine message type. + 2. Decompress Snappy-framed data. + 3. Decode SSZ bytes using the appropriate schema. + + Each step can fail independently.
Failures are wrapped in + GossipMessageError for uniform handling. + + Args: + topic_str: Full topic string (e.g., "/leanconsensus/0x.../block/ssz_snappy"). + compressed_data: Snappy-compressed SSZ data. + + Returns: + Decoded block or attestation. + + Raises: + GossipMessageError: If the message cannot be decoded. + """ + # Step 1: Parse topic to determine message type. + # + # The topic string contains the fork digest and message kind. + # Invalid topics are rejected before any decompression work. + # This prevents wasting CPU on malformed messages. + try: + topic = GossipTopic.from_string(topic_str) + except ValueError as e: + raise GossipMessageError(f"Invalid topic: {e}") from e + + # Step 2: Decompress Snappy-framed data. + # + # Snappy framing splits data into 64KB chunks with CRC32C checksums. + # Decompression fails if: + # - Stream identifier is missing or invalid. + # - Chunk CRC doesn't match (corruption detected). + # - Chunk size exceeds 64KB limit. + # + # Failed decompression indicates network corruption or a malicious peer. + try: + ssz_bytes = frame_decompress(compressed_data) + except SnappyDecompressionError as e: + raise GossipMessageError(f"Snappy decompression failed: {e}") from e + + # Step 3: Decode SSZ based on topic kind. + # + # SSZ decoding fails if the bytes don't match the expected schema. + # For example: wrong length, invalid field values, or truncation. + # + # The topic determines which schema to use. This is why topic + # validation must happen first. + try: + match topic.kind: + case TopicKind.BLOCK: + return SignedBlockWithAttestation.decode_bytes(ssz_bytes) + case TopicKind.ATTESTATION: + return SignedAttestation.decode_bytes(ssz_bytes) + except SSZSerializationError as e: + raise GossipMessageError(f"SSZ decode failed: {e}") from e + + def get_topic(self, topic_str: str) -> GossipTopic: + """ + Parse and validate a topic string. + + Args: + topic_str: Full topic string. + + Returns: + Parsed GossipTopic. 
async def read_gossip_message(
    stream: Stream,
    *,
    max_message_size: int = 10 * 1024 * 1024,
) -> tuple[str, bytes]:
    """
    Read one gossip message from a yamux stream.

    Gossip message wire format::

        [topic_len: varint][topic: UTF-8][data_len: varint][data: bytes]

    Chunks returned by the stream do not have to line up with field
    boundaries, so everything received so far is kept in a buffer and the
    buffer is re-parsed from the start after every chunk:

    1. Decode the topic length varint. If it is incomplete, read more.
    2. Once the topic bytes are all present, decode them as UTF-8.
    3. Decode the data length varint. If it is incomplete, read more.
    4. Once the data bytes are all present, return.

    Bytes that follow the end of the message in the buffer are ignored.

    Args:
        stream: Yamux stream to read from.
        max_message_size: Upper bound, in bytes, on the framed message
            (length prefixes included). The peer controls both length
            prefixes, so without a bound it could make us buffer an
            arbitrary amount of data.
            NOTE(review): the 10 MiB default is assumed to match the
            payload limit used by the req/resp codec - confirm.

    Returns:
        Tuple of (topic_string, compressed_data).

    Raises:
        GossipMessageError: If the stream yields no data at all, closes
            before the message is complete, carries a topic that is not
            valid UTF-8, or announces/sends more than max_message_size bytes.
    """
    buffer = bytearray()

    while True:
        chunk = await stream.read()
        if not chunk:
            # Stream closed. Nothing buffered means the peer sent nothing;
            # anything buffered is an incomplete message (reported below).
            if not buffer:
                raise GossipMessageError("Empty gossip message")
            break
        buffer.extend(chunk)

        # One immutable snapshot per chunk, shared by both varint decodes
        # and both slices, instead of copying the buffer for each decode.
        snapshot = bytes(buffer)

        try:
            topic_len, topic_len_bytes = decode_varint(snapshot, 0)
            topic_end = topic_len_bytes + topic_len

            # Reject an oversized announcement now rather than buffering it.
            if topic_end > max_message_size:
                raise GossipMessageError(
                    f"Gossip topic too large: {topic_len} bytes (limit {max_message_size})"
                )

            if len(snapshot) >= topic_end:
                # Complete topic. Invalid UTF-8 is a protocol violation.
                topic_str = snapshot[topic_len_bytes:topic_end].decode("utf-8")

                if len(snapshot) > topic_end:
                    data_len, data_len_bytes = decode_varint(snapshot, topic_end)
                    data_start = topic_end + data_len_bytes
                    data_end = data_start + data_len

                    if data_end > max_message_size:
                        raise GossipMessageError(
                            f"Gossip message too large: {data_end} bytes "
                            f"(limit {max_message_size})"
                        )

                    if len(snapshot) >= data_end:
                        # Complete message. The caller decompresses and decodes.
                        return topic_str, snapshot[data_start:data_end]

        except VarintError:
            # Treated as "varint not complete yet": fall through and read more.
            pass

        except UnicodeDecodeError as e:
            raise GossipMessageError(f"Invalid topic encoding: {e}") from e

        # Still incomplete. Refuse to keep growing past the limit; this also
        # stops a length prefix that never terminates.
        if len(buffer) > max_message_size:
            raise GossipMessageError(
                f"Gossip message exceeds {max_message_size} bytes before completion"
            )

    # The stream closed before a complete message was received.
    raise GossipMessageError("Truncated gossip message")
+ + + BACKPRESSURE + ------------ + The event queue provides natural backpressure. If the consumer is + slow, the queue grows. Eventually, async iteration semantics cause + producers to wait. """ connection_manager: ConnectionManager - """Underlying transport manager.""" + """Underlying transport manager. + + Handles the full connection stack: TCP, Noise encryption, yamux multiplexing. + """ reqresp_client: ReqRespClient - """Client for req/resp protocol operations.""" + """Client for req/resp protocol operations. + + Used for Status exchange and block/attestation requests. + """ _events: asyncio.Queue[NetworkEvent] = field(default_factory=asyncio.Queue) - """Queue of pending events to yield.""" + """Queue of pending events to yield. + + Events are produced by background tasks and consumed via async iteration. + """ _connections: dict[PeerId, YamuxConnection] = field(default_factory=dict) - """Active connections by peer ID.""" + """Active connections by peer ID. + + Used to route outbound messages and track peer state. + """ _our_status: Status | None = None - """Our current chain status for handshakes.""" + """Our current chain status for handshakes. + + Contains our finalized checkpoint and head. Exchanged with peers on connect. + """ _fork_digest: str = "0x00000000" - """Fork digest for gossip topics.""" + """Fork digest for gossip topics. + + 4-byte identifier derived from genesis validators root and fork version. + Used to validate incoming messages belong to the same fork. + """ _running: bool = False - """Whether the event source is running.""" + """Whether the event source is running. + + Controls the main loop and background tasks. + """ + + _gossip_handler: GossipHandler = field(init=False) + """Handler for decoding incoming gossip messages. + + Initialized with the current fork digest. + """ + + _gossip_tasks: set[asyncio.Task[None]] = field(default_factory=set) + """Background tasks processing incoming gossip streams. 
+ + Tracked for cleanup on shutdown. Tasks remove themselves on completion. + """ + + _reqresp_handler: DefaultRequestHandler = field(init=False) + """Handler for inbound ReqResp requests. + + Provides chain data to peers requesting Status or BlocksByRoot. + """ + + _reqresp_server: ReqRespServer = field(init=False) + """Server for processing inbound ReqResp streams. + + Routes requests to the appropriate handler method. + """ + + def __post_init__(self) -> None: + """Initialize handlers with current configuration.""" + object.__setattr__(self, "_gossip_handler", GossipHandler(fork_digest=self._fork_digest)) + object.__setattr__(self, "_reqresp_handler", DefaultRequestHandler()) + object.__setattr__(self, "_reqresp_server", ReqRespServer(handler=self._reqresp_handler)) @classmethod def create( @@ -108,12 +566,15 @@ def create( def set_status(self, status: Status) -> None: """ - Set our chain status for handshakes. + Set our chain status for handshakes and inbound Status requests. + + Updates both the outbound status exchange and the inbound request handler. Args: status: Our current finalized and head checkpoints. """ self._our_status = status + self._reqresp_handler.our_status = status def set_fork_digest(self, fork_digest: str) -> None: """ @@ -123,6 +584,19 @@ def set_fork_digest(self, fork_digest: str) -> None: fork_digest: 4-byte fork identifier as hex string. """ self._fork_digest = fork_digest + object.__setattr__(self, "_gossip_handler", GossipHandler(fork_digest=fork_digest)) + + def set_block_lookup(self, lookup: BlockLookup) -> None: + """ + Set the callback for looking up blocks by root. + + Used by the inbound ReqResp handler to serve BlocksByRoot requests. + + Args: + lookup: Async function that takes a Bytes32 root and returns + the SignedBlockWithAttestation if available, None otherwise. 
+ """ + self._reqresp_handler.block_lookup = lookup def __aiter__(self) -> LiveNetworkEventSource: """Return self as async iterator.""" @@ -171,6 +645,11 @@ async def dial(self, multiaddr: str) -> PeerId | None: # Exchange status. await self._exchange_status(peer_id, conn) + # Start background task to accept incoming streams. + task = asyncio.create_task(self._accept_streams(peer_id, conn)) + self._gossip_tasks.add(task) + task.add_done_callback(self._gossip_tasks.discard) + logger.info("Connected to peer %s at %s", peer_id, multiaddr) return peer_id @@ -213,6 +692,11 @@ async def _handle_inbound_connection(self, conn: YamuxConnection) -> None: # Exchange status. await self._exchange_status(peer_id, conn) + # Start background task to accept incoming streams. + task = asyncio.create_task(self._accept_streams(peer_id, conn)) + self._gossip_tasks.add(task) + task.add_done_callback(self._gossip_tasks.discard) + logger.info("Accepted connection from peer %s", peer_id) async def _exchange_status( @@ -260,8 +744,10 @@ async def disconnect(self, peer_id: PeerId) -> None: logger.info("Disconnected from peer %s", peer_id) def stop(self) -> None: - """Stop the event source.""" + """Stop the event source and cancel background tasks.""" self._running = False + for task in self._gossip_tasks: + task.cancel() async def _emit_gossip_block( self, @@ -295,6 +781,220 @@ async def _emit_gossip_attestation( GossipAttestationEvent(attestation=attestation, peer_id=peer_id, topic=topic) ) + async def _accept_streams(self, peer_id: PeerId, conn: YamuxConnection) -> None: + """ + Accept incoming streams from a connection. + + Runs in the background, accepting streams and dispatching them to + the appropriate handler based on protocol ID. + + Args: + peer_id: Peer that owns the connection. + conn: Yamux connection to accept streams from. + + + WHY BACKGROUND STREAM ACCEPTANCE? + --------------------------------- + Yamux multiplexing allows peers to open many streams concurrently. 
+ Each stream is an independent request/response conversation. + + Running stream acceptance in the background allows: + + - Concurrent handling of multiple incoming streams. + - Non-blocking connection management. + - Graceful handling of peer disconnection. + + Without background acceptance, the main event loop would block + waiting for streams from one peer while ignoring others. + + + PROTOCOL ID ROUTING + ------------------- + The protocol ID (from multistream-select negotiation) determines + how to handle the stream: + + - "/meshsub/1.1.0": Gossipsub message (block or attestation). + - Other protocols: Req/resp handled elsewhere; close unknown. + + This routing happens at the stream level, not the message level. + Each protocol has its own message format and semantics. + """ + try: + # Main loop: accept streams until shutdown or disconnection. + # + # The loop continues as long as: + # - We haven't been told to stop (_running is True). + # - The peer is still connected (peer_id in _connections). + while self._running and peer_id in self._connections: + try: + # Accept the next incoming stream. + # + # This blocks until a peer opens a stream or the connection closes. + # Yamux handles the low-level multiplexing. + stream = await conn.accept_stream() + except Exception as e: + # Connection closed or other transport error. + # + # This is expected when the peer disconnects. + # Exit the loop cleanly rather than propagating. + logger.debug("Stream accept failed for %s: %s", peer_id, e) + break + + # Route the stream based on its negotiated protocol. + # + # The protocol ID was determined during multistream-select + # when the peer opened the stream. It tells us what kind + # of message to expect. + protocol_id = stream.protocol_id + + if protocol_id == GOSSIPSUB_DEFAULT_PROTOCOL_ID: + # Gossipsub stream: contains a block or attestation. + # + # Handle in a separate task to avoid blocking stream acceptance. 
+ # This allows processing multiple gossip messages concurrently. + task = asyncio.create_task(self._handle_gossip_stream(peer_id, stream)) + self._gossip_tasks.add(task) + task.add_done_callback(self._gossip_tasks.discard) + + elif protocol_id in REQRESP_PROTOCOL_IDS: + # ReqResp stream: Status or BlocksByRoot request. + # + # Handle in a separate task to allow concurrent request processing. + # The ReqRespServer handles decoding, dispatching, and responding. + task = asyncio.create_task( + self._reqresp_server.handle_stream(stream, protocol_id) + ) + self._gossip_tasks.add(task) + task.add_done_callback(self._gossip_tasks.discard) + logger.debug("Handling ReqResp %s from %s", protocol_id, peer_id) + + else: + # Unknown protocol. + # + # Close the stream gracefully. The peer may be running + # a newer client with protocols we don't support. + logger.debug( + "Unknown protocol %s from %s, closing stream", protocol_id, peer_id + ) + await stream.close() + + except asyncio.CancelledError: + # Task was cancelled during shutdown. + # + # This is normal cleanup behavior. Log and exit. + logger.debug("Stream acceptor cancelled for %s", peer_id) + + except Exception as e: + # Unexpected error. + # + # Log as warning since this may indicate a bug. + # The connection will be cleaned up elsewhere. + logger.warning("Stream acceptor error for %s: %s", peer_id, e) + + async def _handle_gossip_stream(self, peer_id: PeerId, stream: Stream) -> None: + """ + Handle an incoming gossip stream. + + Reads the gossip message, decodes it, and emits the appropriate event. + + Args: + peer_id: Peer that sent the message. + stream: Yamux stream containing the gossip message. + + + COMPLETE FLOW + ------------- + A gossip message goes through these stages: + + 1. Read raw bytes from yamux stream. + 2. Parse topic string and data length (varints). + 3. Decompress Snappy-framed data. + 4. Decode SSZ bytes into typed object. + 5. Emit event to the sync layer. + + Any stage can fail. 
Failures are logged but don't crash the handler. + + + ERROR HANDLING STRATEGY + ----------------------- + Gossip is best-effort. A single bad message should not: + + - Crash the node. + - Disconnect the peer. + - Block other messages. + + We log errors and continue. Peer scoring (not implemented here) + would track repeated failures for reputation management. + + + RESOURCE CLEANUP + ---------------- + The stream MUST be closed in finally, even if errors occur. + Unclosed streams leak yamux resources and can cause deadlocks. + """ + try: + # Step 1: Read the gossip message from the stream. + # + # This parses the varint-prefixed topic and data fields. + # May fail if the message is truncated or malformed. + topic_str, compressed_data = await read_gossip_message(stream) + + # Step 2: Decode the message. + # + # This performs: + # - Topic validation (correct prefix, encoding, fork). + # - Snappy decompression with CRC verification. + # - SSZ decoding into the appropriate type. + message = self._gossip_handler.decode_message(topic_str, compressed_data) + topic = self._gossip_handler.get_topic(topic_str) + + # Step 3: Emit the appropriate event based on message type. + # + # The topic determines the expected message type. + # We verify the decoded type matches to catch bugs. + match topic.kind: + case TopicKind.BLOCK: + if isinstance(message, SignedBlockWithAttestation): + await self._emit_gossip_block(message, peer_id) + else: + # Type mismatch indicates a bug in decode_message. + logger.warning("Block topic but got %s", type(message).__name__) + + case TopicKind.ATTESTATION: + if isinstance(message, SignedAttestation): + await self._emit_gossip_attestation(message, peer_id) + else: + # Type mismatch indicates a bug in decode_message. 
+ logger.warning("Attestation topic but got %s", type(message).__name__) + + logger.debug("Received gossip %s from %s", topic.kind.value, peer_id) + + except GossipMessageError as e: + # Expected error: malformed message, decompression failure, etc. + # + # This is not necessarily a bug. The peer may be misbehaving + # or there may be network corruption. Log and continue. + logger.warning("Gossip message error from %s: %s", peer_id, e) + + except Exception as e: + # Unexpected error: likely a bug in our code. + # + # Log as warning to aid debugging. Don't crash. + logger.warning("Unexpected error handling gossip from %s: %s", peer_id, e) + + finally: + # Always close the stream to release yamux resources. + # + # Unclosed streams cause resource leaks and can deadlock + # the connection if too many accumulate. + # + # The try/except suppresses close errors. The stream may + # already be closed if the connection dropped. + try: + await stream.close() + except Exception: + pass + async def publish(self, topic: str, data: bytes) -> None: """ Broadcast a message to all connected peers on a topic. 
diff --git a/src/lean_spec/subspecs/networking/reqresp/__init__.py b/src/lean_spec/subspecs/networking/reqresp/__init__.py index f983a142..07011ed3 100644 --- a/src/lean_spec/subspecs/networking/reqresp/__init__.py +++ b/src/lean_spec/subspecs/networking/reqresp/__init__.py @@ -6,6 +6,15 @@ decode_request, encode_request, ) +from .handler import ( + REQRESP_PROTOCOL_IDS, + BlockLookup, + DefaultRequestHandler, + ReqRespServer, + RequestHandler, + ResponseStream, + YamuxResponseStream, +) from .message import ( BLOCKS_BY_ROOT_PROTOCOL_V1, STATUS_PROTOCOL_V1, @@ -18,6 +27,7 @@ # Protocol IDs "BLOCKS_BY_ROOT_PROTOCOL_V1", "STATUS_PROTOCOL_V1", + "REQRESP_PROTOCOL_IDS", # Message types "BlocksByRootRequest", "BlocksByRootResponse", @@ -27,4 +37,11 @@ "ResponseCode", "encode_request", "decode_request", + # Inbound handlers + "BlockLookup", + "DefaultRequestHandler", + "RequestHandler", + "ReqRespServer", + "ResponseStream", + "YamuxResponseStream", ] diff --git a/src/lean_spec/subspecs/networking/reqresp/handler.py b/src/lean_spec/subspecs/networking/reqresp/handler.py new file mode 100644 index 00000000..3a957196 --- /dev/null +++ b/src/lean_spec/subspecs/networking/reqresp/handler.py @@ -0,0 +1,588 @@ +""" +Inbound ReqResp protocol handlers. + +This module handles incoming peer requests in the Ethereum consensus protocol. +A peer opens a stream, sends a request, and expects one or more response chunks. + + +WHY INBOUND AND OUTBOUND ARE SEPARATE +------------------------------------- +Ethereum's req/resp protocol is asymmetric: + +- Outbound: We initiate. We choose what to ask. +- Inbound: Peer initiates. We must respond correctly. + +The flows mirror each other but have different responsibilities: + + Outbound: open_stream -> encode_request -> write -> read -> decode_response + Inbound: accept_stream -> decode_request -> handle -> encode_response -> write + +Keeping them separate makes each flow easier to understand and test. 
+ + +WHY HANDLERS USE ResponseStream ABSTRACTION +------------------------------------------- +Handlers receive a ResponseStream instead of a raw transport stream. +This design provides three benefits: + +1. Testability: Unit tests provide mock streams without network I/O. +2. Flexibility: Different transports (yamux, memory, etc.) work with the same handlers. +3. Clarity: Handlers focus on protocol logic, not wire format encoding. + +The ResponseStream translates high-level operations (send success, send error) into the +wire format defined in codec.py. + + +WIRE FORMAT +----------- +All responses use the same wire format from codec.py: + + [response_code: 1 byte][varint: uncompressed_length][snappy_framed_payload] + +Response codes: + +- 0 (SUCCESS): Payload contains SSZ-encoded response data +- 1 (INVALID_REQUEST): Peer sent malformed or invalid request +- 2 (SERVER_ERROR): Internal error during processing +- 3 (RESOURCE_UNAVAILABLE): Requested data not found + +Error payloads contain UTF-8 encoded human-readable messages. 
+ + +PROTOCOL IDENTIFIERS +-------------------- +Each request type has a unique protocol ID negotiated via multistream-select: + +- Status: "/leanconsensus/req/status/1/ssz_snappy" +- BlocksByRoot: "/leanconsensus/req/blocks_by_root/1/ssz_snappy" + +The protocol ID determines: + +- Which SSZ type to deserialize the request into +- Which handler processes the request +- What response type(s) the peer expects + + +References: + Ethereum P2P spec: + https://github.com/ethereum/consensus-specs/blob/dev/specs/phase0/p2p-interface.md + Wire format details: + See codec.py in this package +""" + +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field +from typing import Protocol + +from lean_spec.subspecs.containers import SignedBlockWithAttestation +from lean_spec.subspecs.networking.transport.connection.types import Stream +from lean_spec.types import Bytes32 + +from .codec import CodecError, ResponseCode, decode_request +from .message import ( + BLOCKS_BY_ROOT_PROTOCOL_V1, + STATUS_PROTOCOL_V1, + BlocksByRootRequest, + Status, +) + +logger = logging.getLogger(__name__) + +REQUEST_TIMEOUT_SECONDS: float = 10.0 +"""Default timeout for processing inbound requests.""" + + +class ResponseStream(Protocol): + """ + Protocol for sending chunked responses to peers. + + Abstracts the underlying stream transport, allowing handlers to send + responses without knowing the wire format details. + + Response Types + -------------- + - Success: Contains SSZ-encoded response data. + - Error: Contains UTF-8 error message. + + Both types are encoded using the same wire format from codec.py. + """ + + async def send_success(self, ssz_data: bytes) -> None: + """ + Send a SUCCESS response chunk. + + Args: + ssz_data: SSZ-encoded response payload. + """ + ... 
+ + async def send_error(self, code: ResponseCode, message: str) -> None: + """ + Send an error response and close the stream. + + Args: + code: Error code (INVALID_REQUEST, SERVER_ERROR, RESOURCE_UNAVAILABLE). + message: Human-readable error description. + """ + ... + + async def finish(self) -> None: + """ + Signal end of response stream. + + Called after all response chunks have been sent. + Closes the stream gracefully. + """ + ... + + +@dataclass(slots=True) +class YamuxResponseStream: + """ + ResponseStream implementation wrapping a yamux stream. + + Encodes responses using the wire format from codec.py and writes + them to the underlying stream. + """ + + _stream: Stream + """Underlying yamux stream.""" + + async def send_success(self, ssz_data: bytes) -> None: + """ + Send a SUCCESS response chunk. + + Args: + ssz_data: SSZ-encoded response payload. + """ + # Encode the response using the protocol wire format. + # + # ResponseCode.SUCCESS (0x00) tells the peer this chunk contains valid data. + # The encode method handles: + # + # 1. Prepending the response code byte + # 2. Adding the varint length prefix + # 3. Compressing with Snappy framing + encoded = ResponseCode.SUCCESS.encode(ssz_data) + await self._stream.write(encoded) + + async def send_error(self, code: ResponseCode, message: str) -> None: + """ + Send an error response. + + Args: + code: Error code. + message: Human-readable error description. + """ + # Error messages must be UTF-8 encoded per the Ethereum P2P spec. + # + # The spec mandates UTF-8 for interoperability across clients. 
+ # Common error codes: + # + # - INVALID_REQUEST (1): Malformed request, bad SSZ, protocol violation + # - SERVER_ERROR (2): Internal failure, handler exception + # - RESOURCE_UNAVAILABLE (3): Block/blob not found + encoded = code.encode(message.encode("utf-8")) + await self._stream.write(encoded) + + async def finish(self) -> None: + """Close the stream gracefully.""" + await self._stream.close() + + +class RequestHandler(ABC): + """ + Abstract base for request handlers. + + Implementations provide the logic for responding to specific request types. + The sync service or network layer implements this to provide chain data. + + + HANDLER CONTRACT + ---------------- + Handlers MUST: + + - Send at least one response (success or error) via ResponseStream. + - Not raise exceptions (errors should be sent as error responses). + - Be idempotent (same request may arrive multiple times). + + + CONCURRENCY + ----------- + Handlers may be called concurrently for different requests. + Implementations should be thread-safe if accessing shared state. + """ + + @abstractmethod + async def handle_status(self, request: Status, response: ResponseStream) -> None: + """ + Handle incoming Status request. + + The handler should respond with our current chain status. + + Args: + request: Peer's status message. + response: Stream for sending our status response. + """ + ... + + @abstractmethod + async def handle_blocks_by_root( + self, + request: BlocksByRootRequest, + response: ResponseStream, + ) -> None: + """ + Handle incoming BlocksByRoot request. + + The handler should send each requested block as a separate response chunk. + Blocks we do not have should be skipped (or RESOURCE_UNAVAILABLE sent). + + Args: + request: List of block roots being requested. + response: Stream for sending block responses. + """ + ... + + +BlockLookup = Callable[[Bytes32], Awaitable[SignedBlockWithAttestation | None]] +"""Type alias for block lookup function. 
+ +Takes a block root and returns the block if available, None otherwise. +""" + + +@dataclass(slots=True) +class DefaultRequestHandler(RequestHandler): + """ + Default request handler implementation. + + Uses callbacks to retrieve chain data. + Suitable for use with NetworkEventSource. + + + STATUS HANDLING + --------------- + Returns our current status, which must be set via our_status field. + If no status is set, responds with SERVER_ERROR. + + + BLOCKS BY ROOT HANDLING + ----------------------- + Looks up each requested block via the block_lookup callback. + Available blocks are sent as SUCCESS chunks. + Unavailable blocks are silently skipped (per Ethereum P2P spec). + """ + + our_status: Status | None = None + """Our current chain status for Status responses.""" + + block_lookup: BlockLookup | None = None + """Callback to look up blocks by root.""" + + async def handle_status(self, request: Status, response: ResponseStream) -> None: + """ + Handle incoming Status request. + + Responds with our current chain status. + + Args: + request: Peer's status (logged but not used for response). + response: Stream for sending our status. + """ + # Guard: Ensure we have a status configured. + # + # This can happen during node startup before sync completes. + if self.our_status is None: + logger.warning("Status request received but no status configured") + await response.send_error(ResponseCode.SERVER_ERROR, "Status not available") + return + + # Respond with OUR status, not the peer's. + # + # The Status exchange is symmetric: each side sends its own chain state. + # The peer's status (in `request`) is useful for: + # + # - Logging for debugging + # - Peer scoring (handled elsewhere) + # - Fork detection (handled by sync layer) + # + # But it does NOT affect what we respond with. + # We always send our current head and finalized checkpoint. 
+ await response.send_success(self.our_status.encode_bytes()) + + async def handle_blocks_by_root( + self, + request: BlocksByRootRequest, + response: ResponseStream, + ) -> None: + """ + Handle incoming BlocksByRoot request. + + Looks up and sends each requested block. + + Args: + request: Block roots to look up. + response: Stream for sending blocks. + """ + # Guard: Ensure we have a block lookup configured. + if self.block_lookup is None: + logger.warning("BlocksByRoot request received but no block_lookup configured") + await response.send_error(ResponseCode.SERVER_ERROR, "Block lookup not available") + return + + # Process each requested block root. + # + # Key design decisions per Ethereum P2P spec: + # + # 1. Missing blocks are SKIPPED, not errors. + # Peers expect partial responses. They track which roots they received. + # + # 2. Lookup errors are LOGGED and SKIPPED. + # One failed lookup should not prevent returning other blocks. + # + # 3. Order is preserved. + # Blocks are sent in the same order as requested. + for root in request.data: + try: + block = await self.block_lookup(root) + if block is not None: + await response.send_success(block.encode_bytes()) + + # Missing block: Skip silently. + # + # The spec allows partial responses. + # Peers handle missing blocks by requesting from other peers. + # Sending RESOURCE_UNAVAILABLE for each missing block would be noisy. + except Exception as e: + # Lookup error: Log and continue. + # + # Database errors, timeouts, etc. should not abort the response. + # The peer can retry or ask another peer for this specific block. + logger.warning("Error looking up block %s: %s", root.hex()[:8], e) + + +REQRESP_PROTOCOL_IDS: frozenset[str] = frozenset( + { + STATUS_PROTOCOL_V1, + BLOCKS_BY_ROOT_PROTOCOL_V1, + } +) +"""Protocol IDs handled by ReqRespServer.""" + + +@dataclass(slots=True) +class ReqRespServer: + """ + Server for handling inbound ReqResp streams. 
+ + Routes incoming requests to the appropriate handler based on protocol ID. + Handles decoding, dispatching, and error handling. + + + STREAM LIFECYCLE + ---------------- + For each incoming stream: + + 1. Read all request data from stream. + 2. Decode the request (remove length prefix, decompress Snappy). + 3. Deserialize SSZ bytes to the appropriate type. + 4. Dispatch to handler. + 5. Handler sends response(s) via ResponseStream. + 6. Close stream. + + + ERROR HANDLING + -------------- + Errors at any stage result in an error response: + + - Malformed request: INVALID_REQUEST + - Decode failure: INVALID_REQUEST + - Handler error: SERVER_ERROR + """ + + handler: RequestHandler + """Handler for processing requests.""" + + _pending_data: dict[int, bytearray] = field(default_factory=dict) + """Buffer for accumulating request data by stream ID. + + Request data may arrive in multiple chunks. We accumulate until + the stream closes, then process the complete request. + """ + + async def handle_stream(self, stream: Stream, protocol_id: str) -> None: + """ + Handle an incoming ReqResp stream. + + Reads the request, decodes it, and dispatches to the appropriate handler. + + Args: + stream: Incoming yamux stream. + protocol_id: Negotiated protocol ID. + """ + response = YamuxResponseStream(_stream=stream) + + try: + # Step 1: Read the complete request before processing. + # + # Why read everything first? + # + # 1. Simplicity: No streaming parser needed for small requests. + # 2. Validation: Can reject malformed requests before touching state. + # 3. Atomic handling: Either process the full request or reject it. + # + # Request sizes are bounded by MAX_PAYLOAD_SIZE (10 MiB) in codec.py. + # Typical sizes: Status ~100 bytes, BlocksByRoot ~1KB. + data = await self._read_request(stream) + if not data: + await response.send_error(ResponseCode.INVALID_REQUEST, "Empty request") + return + + # Step 2: Decode wire format to raw SSZ bytes. 
+ # + # The wire format wraps SSZ with: + # + # - Varint length prefix (for buffer allocation) + # - Snappy framing (for compression and checksums) + # + # Decoding validates the length matches and checksums pass. + try: + ssz_bytes = decode_request(data) + except CodecError as e: + # Wire format error: malformed varint, bad Snappy, length mismatch. + # + # This is INVALID_REQUEST because the peer sent bad data. + logger.debug("Request decode error: %s", e) + await response.send_error(ResponseCode.INVALID_REQUEST, str(e)) + return + + # Step 3: Dispatch based on protocol ID. + # + # The protocol ID was negotiated via multistream-select before + # this stream was created. It tells us: + # + # - What SSZ type to deserialize into + # - Which handler processes the request + await self._dispatch(protocol_id, ssz_bytes, response) + + except Exception as e: + # Catch-all for unexpected errors. + # + # Any exception reaching here indicates a bug or system failure. + # Send SERVER_ERROR so the peer knows we had an internal problem. + # The peer may retry or try another node. + logger.warning("Unexpected error handling request: %s", e) + try: + await response.send_error(ResponseCode.SERVER_ERROR, "Internal error") + except Exception: + # Write failed. Nothing more we can do. + pass + + finally: + # Always close the stream. + # + # This runs regardless of success or failure. + # Closing signals to the peer that the response is complete. + try: + await response.finish() + except Exception: + # Close failed. Log is unnecessary - peer will timeout. + pass + + async def _read_request(self, stream: Stream) -> bytes: + """ + Read all request data from a stream. + + Accumulates chunks until the stream closes. + + Args: + stream: Stream to read from. + + Returns: + Complete request data. 
+ """ + buffer = bytearray() + while True: + chunk = await stream.read() + if not chunk: + break + buffer.extend(chunk) + return bytes(buffer) + + async def _dispatch( + self, + protocol_id: str, + ssz_bytes: bytes, + response: ResponseStream, + ) -> None: + """ + Dispatch a request to the appropriate handler. + + Args: + protocol_id: Protocol ID identifying the request type. + ssz_bytes: SSZ-encoded request payload. + response: Stream for sending responses. + """ + # Dispatch pattern: Protocol ID determines handler. + # + # Each protocol ID maps to: + # + # 1. An SSZ type for deserialization + # 2. A handler method to process the request + # + # Adding a new request type requires: + # + # - Define the SSZ types in message.py + # - Add the protocol ID constant + # - Add a handler method to RequestHandler + # - Add a branch here + + if protocol_id == STATUS_PROTOCOL_V1: + # Status request: Peer wants our chain state. + # + # SSZ decoding validates: + # + # - Correct size (80 bytes for Status) + # - Valid field offsets + try: + request = Status.decode_bytes(ssz_bytes) + except Exception as e: + # SSZ decode failure: wrong size, malformed offsets, etc. + # + # This is INVALID_REQUEST - the peer sent bad SSZ. + logger.debug("Status decode error: %s", e) + await response.send_error(ResponseCode.INVALID_REQUEST, "Invalid Status message") + return + await self.handler.handle_status(request, response) + + elif protocol_id == BLOCKS_BY_ROOT_PROTOCOL_V1: + # BlocksByRoot request: Peer wants specific blocks by hash. + # + # The request is an SSZ list of 32-byte roots. + # Length must be a multiple of 32 bytes. + try: + request = BlocksByRootRequest.decode_bytes(ssz_bytes) + except Exception as e: + # SSZ decode failure: wrong size, not multiple of 32, etc. 
+ logger.debug("BlocksByRootRequest decode error: %s", e) + await response.send_error( + ResponseCode.INVALID_REQUEST, "Invalid BlocksByRootRequest message" + ) + return + await self.handler.handle_blocks_by_root(request, response) + + else: + # Unknown protocol ID. + # + # This should not happen in normal operation. + # The transport layer filters streams by REQRESP_PROTOCOL_IDS. + # + # If we reach here, it indicates a bug in stream routing. + # Use SERVER_ERROR because this is our fault, not the peer's. + logger.warning("Unknown protocol: %s", protocol_id) + await response.send_error(ResponseCode.SERVER_ERROR, "Unknown protocol") diff --git a/src/lean_spec/subspecs/networking/transport/connection/manager.py b/src/lean_spec/subspecs/networking/transport/connection/manager.py index 83092a6a..2d218592 100644 --- a/src/lean_spec/subspecs/networking/transport/connection/manager.py +++ b/src/lean_spec/subspecs/networking/transport/connection/manager.py @@ -210,6 +210,23 @@ async def open_stream(self, protocol: str) -> Stream: return yamux_stream + async def accept_stream(self) -> Stream: + """ + Accept an incoming stream from the peer. + + Blocks until a new stream is opened by the remote side. + + Returns: + New stream opened by peer. + + Raises: + TransportConnectionError: If connection is closed. + """ + if self._closed: + raise TransportConnectionError("Connection is closed") + + return await self._yamux.accept_stream() + async def close(self) -> None: """Close the connection gracefully.""" if self._closed: diff --git a/src/lean_spec/subspecs/validator/service.py b/src/lean_spec/subspecs/validator/service.py index c4c6734f..3cd3c3a1 100644 --- a/src/lean_spec/subspecs/validator/service.py +++ b/src/lean_spec/subspecs/validator/service.py @@ -12,12 +12,29 @@ This service drives validator duties by monitoring the slot clock and triggering production at the appropriate intervals. 
+Proposer Attestation Design +--------------------------- +Each validator attests exactly once per slot. However, proposers and +non-proposers attest at different times: + +- Proposers attest at interval 0, bundled inside their block +- Non-proposers attest at interval 1, broadcast separately + +This design has two benefits: + +1. Proposers see their own attestation immediately (no network delay) +2. Non-proposers can attest to a block they actually received + +The proposer's attestation is embedded in `BlockWithAttestation` alongside +the block itself. At interval 1, we skip proposers because they already +attested. This prevents double-attestation. + How It Works ------------ 1. Sleep until next interval boundary 2. Check if any validator we control has duties 3. For interval 0: Check proposer schedule, produce block if our turn -4. For interval 1: Produce attestations for all our validators +4. For interval 1: Produce attestations for all non-proposer validators 5. Emit produced blocks/attestations via callbacks 6. Repeat forever """ @@ -151,6 +168,10 @@ async def _maybe_produce_block(self, slot: Slot) -> None: Checks the proposer schedule against our validator registry. If one of our validators should propose, produces and emits the block. + The proposer's attestation is bundled into the block rather than + broadcast separately at interval 1. This ensures the proposer's vote + is included without network round-trip delays. + Args: slot: Current slot number. """ @@ -168,9 +189,14 @@ async def _maybe_produce_block(self, slot: Slot) -> None: if not is_proposer(validator_index, slot, num_validators): continue - # We are the proposer. + # We are the proposer for this slot. # - # Produce the block using Store's production method. + # Block production includes two steps: + # 1. Create the block with aggregated attestations from the pool + # 2. 
Sign and bundle our own attestation into BlockWithAttestation + # + # Our attestation goes in the block envelope, not the body. + # This separates "attestations we're including" from "our own vote". try: new_store, block, signatures = store.produce_block_with_signatures( slot=slot, @@ -183,6 +209,8 @@ async def _maybe_produce_block(self, slot: Slot) -> None: self.sync_service.store = new_store # Create signed block wrapper for publishing. + # + # This adds our attestation and signatures to the block. signed_block = self._sign_block(block, validator_index, signatures) self._blocks_produced += 1 metrics.blocks_proposed.inc() @@ -201,16 +229,31 @@ async def _maybe_produce_block(self, slot: Slot) -> None: async def _produce_attestations(self, slot: Slot) -> None: """ - Produce attestations for all validators we control. + Produce attestations for all non-proposer validators we control. - Every validator should attest once per slot. + Every validator attests exactly once per slot. Since proposers already + bundled their attestation inside the block at interval 0, they are + skipped here to prevent double-attestation. Args: slot: Current slot number. """ store = self.sync_service.store + head_state = store.states.get(store.head) + if head_state is None: + return + + num_validators = Uint64(len(head_state.validators)) for validator_index in self.registry.indices(): + # Skip proposer - they already attested within their block. + # + # The proposer signed and bundled their attestation at interval 0. + # Creating another attestation here would violate the + # "one attestation per validator per slot" invariant. + if is_proposer(validator_index, slot, num_validators): + continue + # Produce attestation data using Store's method. # # This calculates head, target, and source checkpoints. 
diff --git a/tests/lean_spec/subspecs/networking/client/__init__.py b/tests/lean_spec/subspecs/networking/client/__init__.py new file mode 100644 index 00000000..95c7a81b --- /dev/null +++ b/tests/lean_spec/subspecs/networking/client/__init__.py @@ -0,0 +1 @@ +"""Tests for networking client module.""" diff --git a/tests/lean_spec/subspecs/networking/client/test_gossip_reception.py b/tests/lean_spec/subspecs/networking/client/test_gossip_reception.py new file mode 100644 index 00000000..19602692 --- /dev/null +++ b/tests/lean_spec/subspecs/networking/client/test_gossip_reception.py @@ -0,0 +1,699 @@ +"""Tests for gossip message reception functionality. + +This module tests the GossipHandler class, GossipMessageError exception, +and read_gossip_message async function that handle incoming gossip messages +from peers in the P2P network. + +Gossip message format: +- Topic length (varint) +- Topic string (UTF-8) +- Data length (varint) +- Data (Snappy-compressed SSZ) +""" + +from __future__ import annotations + +import asyncio + +import pytest + +from lean_spec.snappy import frame_compress +from lean_spec.subspecs.containers import SignedBlockWithAttestation +from lean_spec.subspecs.containers.attestation import SignedAttestation +from lean_spec.subspecs.containers.checkpoint import Checkpoint +from lean_spec.subspecs.containers.slot import Slot +from lean_spec.subspecs.networking.client.event_source import ( + GossipHandler, + GossipMessageError, + read_gossip_message, +) +from lean_spec.subspecs.networking.gossipsub.topic import ( + ENCODING_POSTFIX, + TOPIC_PREFIX, + GossipTopic, + TopicKind, +) +from lean_spec.subspecs.networking.varint import encode as encode_varint +from lean_spec.types import Bytes32, Uint64 +from tests.lean_spec.helpers.builders import make_signed_attestation, make_signed_block + +# ============================================================================= +# Test Fixtures and Helpers +# 
============================================================================= + + +class MockStream: + """ + A mock stream for testing read_gossip_message. + + Simulates a yamux stream by returning data in chunks. + """ + + def __init__(self, data: bytes, chunk_size: int = 1024) -> None: + """ + Initialize the mock stream. + + Args: + data: Complete data to return from reads. + chunk_size: Maximum bytes per read call. + """ + self.data = data + self.chunk_size = chunk_size + self.offset = 0 + + @property + def protocol_id(self) -> str: + """Return a mock protocol ID.""" + return "/meshsub/1.1.0" + + async def read(self, n: int = -1) -> bytes: + """ + Read data from the mock stream. + + Args: + n: Ignored, uses chunk_size instead. + + Returns: + Next chunk of data, or empty bytes if exhausted. + """ + if self.offset >= len(self.data): + return b"" + end = min(self.offset + self.chunk_size, len(self.data)) + chunk = self.data[self.offset : end] + self.offset = end + return chunk + + async def write(self, data: bytes) -> None: + """Mock write (not used in reception tests).""" + pass + + async def close(self) -> None: + """Mock close.""" + pass + + async def reset(self) -> None: + """Mock reset.""" + pass + + +def make_block_topic(fork_digest: str = "0x00000000") -> str: + """Create a valid block topic string.""" + return f"/{TOPIC_PREFIX}/{fork_digest}/block/{ENCODING_POSTFIX}" + + +def make_attestation_topic(fork_digest: str = "0x00000000") -> str: + """Create a valid attestation topic string.""" + return f"/{TOPIC_PREFIX}/{fork_digest}/attestation/{ENCODING_POSTFIX}" + + +def make_test_signed_block() -> SignedBlockWithAttestation: + """Create a minimal signed block for testing.""" + return make_signed_block( + slot=Slot(1), + proposer_index=Uint64(0), + parent_root=Bytes32.zero(), + state_root=Bytes32.zero(), + ) + + +def make_test_signed_attestation() -> SignedAttestation: + """Create a minimal signed attestation for testing.""" + return make_signed_attestation( + 
validator=Uint64(0), + target=Checkpoint(root=Bytes32.zero(), slot=Slot(1)), + ) + + +def build_gossip_message(topic: str, ssz_data: bytes) -> bytes: + """ + Build a complete gossip message from topic and SSZ data. + + Format: [topic_len varint][topic][data_len varint][compressed_data] + """ + topic_bytes = topic.encode("utf-8") + compressed_data = frame_compress(ssz_data) + + message = bytearray() + message.extend(encode_varint(len(topic_bytes))) + message.extend(topic_bytes) + message.extend(encode_varint(len(compressed_data))) + message.extend(compressed_data) + + return bytes(message) + + +# ============================================================================= +# Tests for GossipMessageError +# ============================================================================= + + +class TestGossipMessageError: + """Tests for the GossipMessageError exception.""" + + def test_is_exception_subclass(self) -> None: + """GossipMessageError inherits from Exception.""" + assert issubclass(GossipMessageError, Exception) + + def test_message_preserved(self) -> None: + """Error message is preserved.""" + msg = "Test error message" + error = GossipMessageError(msg) + assert str(error) == msg + + def test_can_be_raised_and_caught(self) -> None: + """Can be raised and caught properly.""" + with pytest.raises(GossipMessageError, match="specific error"): + raise GossipMessageError("specific error") + + +# ============================================================================= +# Tests for GossipHandler.get_topic() +# ============================================================================= + + +class TestGossipHandlerGetTopic: + """Tests for GossipHandler.get_topic() method.""" + + def test_valid_block_topic(self) -> None: + """Parses valid block topic string.""" + handler = GossipHandler(fork_digest="0x12345678") + topic_str = "/leanconsensus/0x12345678/block/ssz_snappy" + + topic = handler.get_topic(topic_str) + + assert isinstance(topic, GossipTopic) + assert 
topic.kind == TopicKind.BLOCK + assert topic.fork_digest == "0x12345678" + + def test_valid_attestation_topic(self) -> None: + """Parses valid attestation topic string.""" + handler = GossipHandler(fork_digest="0x00000000") + topic_str = "/leanconsensus/0x00000000/attestation/ssz_snappy" + + topic = handler.get_topic(topic_str) + + assert isinstance(topic, GossipTopic) + assert topic.kind == TopicKind.ATTESTATION + assert topic.fork_digest == "0x00000000" + + def test_invalid_topic_format_missing_parts(self) -> None: + """Raises GossipMessageError for topic with missing parts.""" + handler = GossipHandler(fork_digest="0x00000000") + + with pytest.raises(GossipMessageError, match="Invalid topic"): + handler.get_topic("/invalid/topic") + + def test_invalid_topic_format_wrong_prefix(self) -> None: + """Raises GossipMessageError for wrong network prefix.""" + handler = GossipHandler(fork_digest="0x00000000") + + with pytest.raises(GossipMessageError, match="Invalid topic"): + handler.get_topic("/wrongprefix/0x00000000/block/ssz_snappy") + + def test_invalid_topic_format_wrong_encoding(self) -> None: + """Raises GossipMessageError for wrong encoding suffix.""" + handler = GossipHandler(fork_digest="0x00000000") + + with pytest.raises(GossipMessageError, match="Invalid topic"): + handler.get_topic("/leanconsensus/0x00000000/block/ssz") + + def test_invalid_topic_format_unknown_topic_name(self) -> None: + """Raises GossipMessageError for unknown topic name.""" + handler = GossipHandler(fork_digest="0x00000000") + + with pytest.raises(GossipMessageError, match="Invalid topic"): + handler.get_topic("/leanconsensus/0x00000000/unknown/ssz_snappy") + + def test_empty_topic_string(self) -> None: + """Raises GossipMessageError for empty topic string.""" + handler = GossipHandler(fork_digest="0x00000000") + + with pytest.raises(GossipMessageError, match="Invalid topic"): + handler.get_topic("") + + +# ============================================================================= 
+# Tests for GossipHandler.decode_message() +# ============================================================================= + + +class TestGossipHandlerDecodeMessage: + """Tests for GossipHandler.decode_message() method.""" + + def test_decode_valid_block_message(self) -> None: + """Decodes valid block message correctly.""" + handler = GossipHandler(fork_digest="0x00000000") + block = make_test_signed_block() + ssz_bytes = block.encode_bytes() + compressed = frame_compress(ssz_bytes) + topic_str = make_block_topic() + + result = handler.decode_message(topic_str, compressed) + + assert isinstance(result, SignedBlockWithAttestation) + + def test_decode_valid_attestation_message(self) -> None: + """Decodes valid attestation message correctly.""" + handler = GossipHandler(fork_digest="0x00000000") + attestation = make_test_signed_attestation() + ssz_bytes = attestation.encode_bytes() + compressed = frame_compress(ssz_bytes) + topic_str = make_attestation_topic() + + result = handler.decode_message(topic_str, compressed) + + assert isinstance(result, SignedAttestation) + + def test_decode_invalid_topic_format(self) -> None: + """Raises GossipMessageError for invalid topic format.""" + handler = GossipHandler(fork_digest="0x00000000") + compressed = frame_compress(b"\x00" * 32) + + with pytest.raises(GossipMessageError, match="Invalid topic"): + handler.decode_message("/bad/topic", compressed) + + def test_decode_invalid_snappy_compression(self) -> None: + """Raises GossipMessageError for invalid Snappy data.""" + handler = GossipHandler(fork_digest="0x00000000") + topic_str = make_block_topic() + invalid_snappy = b"\x00\x01\x02\x03" # Not valid Snappy framed data + + with pytest.raises(GossipMessageError, match="Snappy decompression failed"): + handler.decode_message(topic_str, invalid_snappy) + + def test_decode_invalid_ssz_encoding(self) -> None: + """Raises GossipMessageError for invalid SSZ data.""" + handler = GossipHandler(fork_digest="0x00000000") + topic_str = 
make_block_topic() + # Valid Snappy wrapping garbage SSZ + compressed = frame_compress(b"\xff\xff\xff\xff") + + with pytest.raises(GossipMessageError, match="SSZ decode failed"): + handler.decode_message(topic_str, compressed) + + def test_decode_empty_snappy_data(self) -> None: + """Raises GossipMessageError for empty compressed data.""" + handler = GossipHandler(fork_digest="0x00000000") + topic_str = make_block_topic() + + with pytest.raises(GossipMessageError, match="Snappy decompression failed"): + handler.decode_message(topic_str, b"") + + def test_decode_truncated_ssz_data(self) -> None: + """Raises GossipMessageError for truncated SSZ data.""" + handler = GossipHandler(fork_digest="0x00000000") + block = make_test_signed_block() + ssz_bytes = block.encode_bytes() + truncated = ssz_bytes[:10] # Truncate SSZ data + compressed = frame_compress(truncated) + topic_str = make_block_topic() + + with pytest.raises(GossipMessageError, match="SSZ decode failed"): + handler.decode_message(topic_str, compressed) + + +# ============================================================================= +# Tests for read_gossip_message() +# ============================================================================= + + +class TestReadGossipMessage: + """Tests for the read_gossip_message async function.""" + + def test_read_valid_block_message(self) -> None: + """Reads valid block message from stream.""" + + async def run() -> tuple[str, bytes]: + block = make_test_signed_block() + ssz_bytes = block.encode_bytes() + topic_str = make_block_topic() + message_data = build_gossip_message(topic_str, ssz_bytes) + + stream = MockStream(message_data) + return await read_gossip_message(stream) + + topic, compressed = asyncio.run(run()) + topic_str = make_block_topic() + + assert topic == topic_str + assert len(compressed) > 0 + + def test_read_valid_attestation_message(self) -> None: + """Reads valid attestation message from stream.""" + + async def run() -> tuple[str, bytes]: + 
attestation = make_test_signed_attestation() + ssz_bytes = attestation.encode_bytes() + topic_str = make_attestation_topic() + message_data = build_gossip_message(topic_str, ssz_bytes) + + stream = MockStream(message_data) + return await read_gossip_message(stream) + + topic, compressed = asyncio.run(run()) + topic_str = make_attestation_topic() + + assert topic == topic_str + assert len(compressed) > 0 + + def test_read_empty_stream(self) -> None: + """Raises GossipMessageError for empty stream.""" + + async def run() -> tuple[str, bytes]: + stream = MockStream(b"") + return await read_gossip_message(stream) + + with pytest.raises(GossipMessageError, match="Empty gossip message"): + asyncio.run(run()) + + def test_read_truncated_topic_length(self) -> None: + """Raises GossipMessageError for incomplete topic length varint.""" + + async def run() -> tuple[str, bytes]: + # A varint byte with continuation bit set but no following bytes + incomplete_varint = b"\x80" # Continuation bit set, needs more bytes + stream = MockStream(incomplete_varint) + return await read_gossip_message(stream) + + with pytest.raises(GossipMessageError, match="Truncated gossip message"): + asyncio.run(run()) + + def test_read_truncated_topic_string(self) -> None: + """Raises GossipMessageError for truncated topic string.""" + + async def run() -> tuple[str, bytes]: + topic = make_block_topic() + topic_bytes = topic.encode("utf-8") + # Claim topic is 100 bytes but only provide partial data + truncated = encode_varint(100) + topic_bytes[:10] + stream = MockStream(truncated) + return await read_gossip_message(stream) + + with pytest.raises(GossipMessageError, match="Truncated gossip message"): + asyncio.run(run()) + + def test_read_truncated_data_length(self) -> None: + """Raises GossipMessageError for truncated data length varint.""" + + async def run() -> tuple[str, bytes]: + topic = make_block_topic() + topic_bytes = topic.encode("utf-8") + # Complete topic but incomplete data length varint 
+ data = encode_varint(len(topic_bytes)) + topic_bytes + b"\x80" + stream = MockStream(data) + return await read_gossip_message(stream) + + with pytest.raises(GossipMessageError, match="Truncated gossip message"): + asyncio.run(run()) + + def test_read_truncated_data(self) -> None: + """Raises GossipMessageError for truncated message data.""" + + async def run() -> tuple[str, bytes]: + topic = make_block_topic() + topic_bytes = topic.encode("utf-8") + compressed = frame_compress(b"test data") + # Claim data is 1000 bytes but only provide partial + data = ( + encode_varint(len(topic_bytes)) + topic_bytes + encode_varint(1000) + compressed[:5] + ) + stream = MockStream(data) + return await read_gossip_message(stream) + + with pytest.raises(GossipMessageError, match="Truncated gossip message"): + asyncio.run(run()) + + def test_read_invalid_utf8_topic(self) -> None: + """Raises GossipMessageError for invalid UTF-8 in topic.""" + + async def run() -> tuple[str, bytes]: + # Invalid UTF-8 sequence + invalid_utf8 = b"\xff\xfe" + data = encode_varint(len(invalid_utf8)) + invalid_utf8 + # Add data portion + data += encode_varint(4) + b"test" + stream = MockStream(data) + return await read_gossip_message(stream) + + with pytest.raises(GossipMessageError, match="Invalid topic encoding"): + asyncio.run(run()) + + def test_read_small_chunks(self) -> None: + """Successfully reads message delivered in small chunks.""" + + async def run() -> tuple[str, bytes]: + block = make_test_signed_block() + ssz_bytes = block.encode_bytes() + topic_str = make_block_topic() + message_data = build_gossip_message(topic_str, ssz_bytes) + + # Use tiny chunks to test incremental parsing + stream = MockStream(message_data, chunk_size=5) + return await read_gossip_message(stream) + + topic, compressed = asyncio.run(run()) + topic_str = make_block_topic() + + assert topic == topic_str + assert len(compressed) > 0 + + def test_read_large_message(self) -> None: + """Successfully reads larger gossip 
message.""" + + async def run() -> tuple[str, bytes, bytes]: + block = make_test_signed_block() + ssz_bytes = block.encode_bytes() + topic_str = make_block_topic() + message_data = build_gossip_message(topic_str, ssz_bytes) + + stream = MockStream(message_data) + topic, compressed = await read_gossip_message(stream) + return topic, compressed, ssz_bytes + + topic, compressed, ssz_bytes = asyncio.run(run()) + topic_str = make_block_topic() + + assert topic == topic_str + # Verify the compressed data can be decompressed + from lean_spec.snappy import frame_decompress + + decompressed = frame_decompress(compressed) + assert decompressed == ssz_bytes + + def test_read_single_byte_chunks(self) -> None: + """Successfully reads message with single-byte chunks.""" + + async def run() -> tuple[str, bytes]: + attestation = make_test_signed_attestation() + ssz_bytes = attestation.encode_bytes() + topic_str = make_attestation_topic() + message_data = build_gossip_message(topic_str, ssz_bytes) + + # Single byte at a time - stress test incremental parsing + stream = MockStream(message_data, chunk_size=1) + return await read_gossip_message(stream) + + topic, _ = asyncio.run(run()) + topic_str = make_attestation_topic() + + assert topic == topic_str + + +# ============================================================================= +# Integration Tests +# ============================================================================= + + +class TestGossipReceptionIntegration: + """Integration tests for the complete gossip reception flow.""" + + def test_full_block_reception_flow(self) -> None: + """Tests complete flow: stream -> parse -> decompress -> decode.""" + + async def run() -> tuple[SignedBlockWithAttestation | SignedAttestation, bytes]: + handler = GossipHandler(fork_digest="0x00000000") + original_block = make_test_signed_block() + ssz_bytes = original_block.encode_bytes() + topic_str = make_block_topic() + message_data = build_gossip_message(topic_str, ssz_bytes) + + # 
Step 1: Read from stream + stream = MockStream(message_data) + parsed_topic, compressed = await read_gossip_message(stream) + + # Step 2: Decode message + decoded = handler.decode_message(parsed_topic, compressed) + + return decoded, original_block.encode_bytes() + + decoded, original_bytes = asyncio.run(run()) + + # Step 3: Verify result + assert isinstance(decoded, SignedBlockWithAttestation) + assert decoded.encode_bytes() == original_bytes + + def test_full_attestation_reception_flow(self) -> None: + """Tests complete flow for attestation messages.""" + + async def run() -> tuple[SignedBlockWithAttestation | SignedAttestation, bytes, TopicKind]: + handler = GossipHandler(fork_digest="0x00000000") + original_attestation = make_test_signed_attestation() + ssz_bytes = original_attestation.encode_bytes() + topic_str = make_attestation_topic() + message_data = build_gossip_message(topic_str, ssz_bytes) + + # Step 1: Read from stream + stream = MockStream(message_data) + parsed_topic, compressed = await read_gossip_message(stream) + + # Step 2: Get topic info + topic = handler.get_topic(parsed_topic) + + # Step 3: Decode message + decoded = handler.decode_message(parsed_topic, compressed) + + return decoded, original_attestation.encode_bytes(), topic.kind + + decoded, original_bytes, topic_kind = asyncio.run(run()) + + # Step 4: Verify result + assert topic_kind == TopicKind.ATTESTATION + assert isinstance(decoded, SignedAttestation) + assert decoded.encode_bytes() == original_bytes + + def test_handler_fork_digest_stored(self) -> None: + """Handler stores fork digest for topic validation.""" + digest = "0xaabbccdd" + handler = GossipHandler(fork_digest=digest) + assert handler.fork_digest == digest + + def test_roundtrip_preserves_data_integrity(self) -> None: + """Data integrity preserved through encode-compress-stream-decompress-decode.""" + + async def run() -> tuple[bytes, bytes]: + handler = GossipHandler(fork_digest="0x00000000") + original = 
make_test_signed_block() + original_bytes = original.encode_bytes() + + # Encode and compress + topic_str = make_block_topic() + message_data = build_gossip_message(topic_str, original_bytes) + + # Simulate network transfer via stream + stream = MockStream(message_data) + _, compressed = await read_gossip_message(stream) + + # Decode + decoded = handler.decode_message(topic_str, compressed) + decoded_bytes = decoded.encode_bytes() + + return decoded_bytes, original_bytes + + decoded_bytes, original_bytes = asyncio.run(run()) + + # Verify exact match + assert decoded_bytes == original_bytes + + +# ============================================================================= +# Edge Case Tests +# ============================================================================= + + +class TestGossipReceptionEdgeCases: + """Edge case tests for gossip reception.""" + + def test_handler_with_different_fork_digests(self) -> None: + """Handler works with various fork digest formats.""" + for digest in ["0x00000000", "0xffffffff", "0x12345678", "0xabcdef01"]: + handler = GossipHandler(fork_digest=digest) + topic_str = f"/{TOPIC_PREFIX}/{digest}/block/{ENCODING_POSTFIX}" + topic = handler.get_topic(topic_str) + assert topic.fork_digest == digest + + def test_zero_length_compressed_data(self) -> None: + """Handles message with zero-length data field.""" + + async def run() -> tuple[str, bytes]: + topic = make_block_topic() + topic_bytes = topic.encode("utf-8") + # Zero-length data + data = encode_varint(len(topic_bytes)) + topic_bytes + encode_varint(0) + stream = MockStream(data) + return await read_gossip_message(stream) + + topic_result, compressed = asyncio.run(run()) + topic = make_block_topic() + assert topic_result == topic + assert compressed == b"" + + def test_decode_corrupted_snappy_crc(self) -> None: + """Detects CRC corruption in Snappy framed data.""" + handler = GossipHandler(fork_digest="0x00000000") + block = make_test_signed_block() + ssz_bytes = 
block.encode_bytes() + compressed = bytearray(frame_compress(ssz_bytes)) + + # Corrupt the CRC (located after stream identifier and chunk header) + if len(compressed) > 14: + compressed[14] ^= 0xFF + + topic_str = make_block_topic() + + with pytest.raises(GossipMessageError, match="Snappy decompression failed"): + handler.decode_message(topic_str, bytes(compressed)) + + def test_very_long_topic_string(self) -> None: + """Handles messages with unusually long topic strings.""" + + async def run() -> str: + # Create a long but valid-format topic + long_digest = "0x" + "a" * 100 + topic = f"/{TOPIC_PREFIX}/{long_digest}/block/{ENCODING_POSTFIX}" + topic_bytes = topic.encode("utf-8") + compressed = frame_compress(b"test") + + data = encode_varint(len(topic_bytes)) + topic_bytes + data += encode_varint(len(compressed)) + compressed + + stream = MockStream(data) + parsed_topic, _ = await read_gossip_message(stream) + return parsed_topic + + parsed_topic = asyncio.run(run()) + long_digest = "0x" + "a" * 100 + expected_topic = f"/{TOPIC_PREFIX}/{long_digest}/block/{ENCODING_POSTFIX}" + + assert parsed_topic == expected_topic + + @pytest.mark.parametrize( + "invalid_data", + [ + b"\x00", # Just a zero byte (topic length 0) + b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80", # Overlong varint + ], + ids=["zero_byte_topic_length", "overlong_varint"], + ) + def test_malformed_varint_data(self, invalid_data: bytes) -> None: + """Handles various malformed varint patterns.""" + + async def run() -> tuple[str, bytes]: + stream = MockStream(invalid_data) + return await read_gossip_message(stream) + + with pytest.raises(GossipMessageError): + asyncio.run(run()) + + def test_topic_only_message_missing_data(self) -> None: + """Raises error when message has topic but no data section.""" + + async def run() -> tuple[str, bytes]: + topic = make_block_topic() + topic_bytes = topic.encode("utf-8") + # Only topic, no data length or data + data = encode_varint(len(topic_bytes)) + topic_bytes + 
stream = MockStream(data) + return await read_gossip_message(stream) + + with pytest.raises(GossipMessageError, match="Truncated gossip message"): + asyncio.run(run()) diff --git a/tests/lean_spec/subspecs/networking/reqresp/__init__.py b/tests/lean_spec/subspecs/networking/reqresp/__init__.py new file mode 100644 index 00000000..13e47708 --- /dev/null +++ b/tests/lean_spec/subspecs/networking/reqresp/__init__.py @@ -0,0 +1 @@ +"""Tests for ReqResp protocol handlers.""" diff --git a/tests/lean_spec/subspecs/networking/reqresp/test_handler.py b/tests/lean_spec/subspecs/networking/reqresp/test_handler.py new file mode 100644 index 00000000..db9f7e29 --- /dev/null +++ b/tests/lean_spec/subspecs/networking/reqresp/test_handler.py @@ -0,0 +1,1565 @@ +"""Tests for inbound ReqResp protocol handlers.""" + +from __future__ import annotations + +import asyncio +from collections.abc import Coroutine +from dataclasses import dataclass, field +from typing import TypeVar + +from lean_spec.subspecs.containers import Checkpoint, SignedBlockWithAttestation +from lean_spec.subspecs.containers.slot import Slot +from lean_spec.subspecs.networking.reqresp.codec import ( + ResponseCode, + encode_request, +) +from lean_spec.subspecs.networking.reqresp.handler import ( + REQRESP_PROTOCOL_IDS, + BlockLookup, + DefaultRequestHandler, + ReqRespServer, + YamuxResponseStream, +) +from lean_spec.subspecs.networking.reqresp.message import ( + BLOCKS_BY_ROOT_PROTOCOL_V1, + STATUS_PROTOCOL_V1, + BlocksByRootRequest, + Status, +) +from lean_spec.types import Bytes32, Uint64 +from tests.lean_spec.helpers import make_signed_block + +_T = TypeVar("_T") + +# ----------------------------------------------------------------------------- +# Mock Classes +# ----------------------------------------------------------------------------- + + +@dataclass +class MockStream: + """Mock yamux stream for testing ReqRespServer.""" + + request_data: bytes = b"" + """Data to return when read() is called.""" + + 
written: list[bytes] = field(default_factory=list) + """Accumulator for data written to the stream.""" + + closed: bool = False + """Whether close() has been called.""" + + _read_offset: int = 0 + """Internal offset for simulating chunked reads.""" + + @property + def protocol_id(self) -> str: + """Mock protocol ID.""" + return STATUS_PROTOCOL_V1 + + async def read(self, n: int = -1) -> bytes: + """ + Return request data in a single chunk, then empty bytes. + + Simulates the stream EOF behavior. + """ + if self._read_offset >= len(self.request_data): + return b"" + chunk = self.request_data[self._read_offset :] + self._read_offset = len(self.request_data) + return chunk + + async def write(self, data: bytes) -> None: + """Accumulate written data for inspection.""" + self.written.append(data) + + async def close(self) -> None: + """Mark stream as closed.""" + self.closed = True + + async def reset(self) -> None: + """Abort the stream.""" + self.closed = True + + +@dataclass +class MockResponseStream: + """Mock ResponseStream for testing handlers in isolation.""" + + successes: list[bytes] = field(default_factory=list) + """SSZ data sent via send_success.""" + + errors: list[tuple[ResponseCode, str]] = field(default_factory=list) + """Errors sent via send_error as (code, message) tuples.""" + + finished: bool = False + """Whether finish() was called.""" + + async def send_success(self, ssz_data: bytes) -> None: + """Record a success response.""" + self.successes.append(ssz_data) + + async def send_error(self, code: ResponseCode, message: str) -> None: + """Record an error response.""" + self.errors.append((code, message)) + + async def finish(self) -> None: + """Mark stream as finished.""" + self.finished = True + + +# ----------------------------------------------------------------------------- +# Test Helpers +# ----------------------------------------------------------------------------- + + +def make_test_status() -> Status: + """Create a valid Status message for 
testing.""" + return Status( + finalized=Checkpoint(root=Bytes32(b"\x01" * 32), slot=Slot(100)), + head=Checkpoint(root=Bytes32(b"\x02" * 32), slot=Slot(200)), + ) + + +def make_test_block(slot: int = 1, seed: int = 0) -> SignedBlockWithAttestation: + """Create a valid SignedBlockWithAttestation for testing.""" + return make_signed_block( + slot=Slot(slot), + proposer_index=Uint64(0), + parent_root=Bytes32(bytes([seed]) * 32), + state_root=Bytes32(bytes([seed + 1]) * 32), + ) + + +def run_async(coro: Coroutine[object, object, _T]) -> _T: + """Run an async coroutine synchronously.""" + return asyncio.run(coro) + + +# ----------------------------------------------------------------------------- +# TestYamuxResponseStream +# ----------------------------------------------------------------------------- + + +class TestYamuxResponseStream: + """Tests for YamuxResponseStream wire format encoding.""" + + def test_send_success_encodes_correctly(self) -> None: + """Success response uses SUCCESS code and encodes SSZ data.""" + + async def run_test() -> tuple[list[bytes], bool]: + stream = MockStream() + response = YamuxResponseStream(_stream=stream) + + ssz_data = b"\x01\x02\x03\x04" + await response.send_success(ssz_data) + + return stream.written, stream.closed + + written, closed = run_async(run_test()) + + assert len(written) == 1 + encoded = written[0] + + # First byte should be SUCCESS (0) + assert encoded[0] == ResponseCode.SUCCESS + + # Should decode back to original data + code, decoded = ResponseCode.decode(encoded) + assert code == ResponseCode.SUCCESS + assert decoded == b"\x01\x02\x03\x04" + + def test_send_error_encodes_correctly(self) -> None: + """Error response uses specified code and UTF-8 message.""" + + async def run_test() -> list[bytes]: + stream = MockStream() + response = YamuxResponseStream(_stream=stream) + + await response.send_error(ResponseCode.INVALID_REQUEST, "Bad request") + + return stream.written + + written = run_async(run_test()) + + assert 
len(written) == 1 + encoded = written[0] + + # First byte should be INVALID_REQUEST (1) + assert encoded[0] == ResponseCode.INVALID_REQUEST + + # Should decode back to UTF-8 message + code, decoded = ResponseCode.decode(encoded) + assert code == ResponseCode.INVALID_REQUEST + assert decoded == b"Bad request" + + def test_send_error_server_error(self) -> None: + """SERVER_ERROR code encodes correctly.""" + + async def run_test() -> list[bytes]: + stream = MockStream() + response = YamuxResponseStream(_stream=stream) + + await response.send_error(ResponseCode.SERVER_ERROR, "Internal error") + + return stream.written + + written = run_async(run_test()) + encoded = written[0] + + code, decoded = ResponseCode.decode(encoded) + assert code == ResponseCode.SERVER_ERROR + assert decoded == b"Internal error" + + def test_send_error_resource_unavailable(self) -> None: + """RESOURCE_UNAVAILABLE code encodes correctly.""" + + async def run_test() -> list[bytes]: + stream = MockStream() + response = YamuxResponseStream(_stream=stream) + + await response.send_error(ResponseCode.RESOURCE_UNAVAILABLE, "Block not found") + + return stream.written + + written = run_async(run_test()) + encoded = written[0] + + code, decoded = ResponseCode.decode(encoded) + assert code == ResponseCode.RESOURCE_UNAVAILABLE + assert decoded == b"Block not found" + + def test_finish_closes_stream(self) -> None: + """Finish closes the underlying stream.""" + + async def run_test() -> bool: + stream = MockStream() + response = YamuxResponseStream(_stream=stream) + + assert not stream.closed + await response.finish() + return stream.closed + + closed = run_async(run_test()) + assert closed is True + + +# ----------------------------------------------------------------------------- +# TestDefaultRequestHandler - Status +# ----------------------------------------------------------------------------- + + +class TestDefaultRequestHandlerStatus: + """Tests for DefaultRequestHandler.handle_status.""" + + def 
test_handle_status_returns_our_status(self) -> None: + """Returns our configured status on valid request.""" + + async def run_test() -> tuple[list[bytes], list[tuple[ResponseCode, str]]]: + our_status = make_test_status() + handler = DefaultRequestHandler(our_status=our_status) + response = MockResponseStream() + + peer_status = Status( + finalized=Checkpoint(root=Bytes32(b"\xaa" * 32), slot=Slot(50)), + head=Checkpoint(root=Bytes32(b"\xbb" * 32), slot=Slot(150)), + ) + + await handler.handle_status(peer_status, response) + + return response.successes, response.errors + + successes, errors = run_async(run_test()) + + assert len(errors) == 0 + assert len(successes) == 1 + + # Decode the SSZ response + returned_status = Status.decode_bytes(successes[0]) + assert returned_status.head.slot == Slot(200) + assert returned_status.finalized.slot == Slot(100) + + def test_handle_status_no_status_returns_error(self) -> None: + """Returns SERVER_ERROR when no status is configured.""" + + async def run_test() -> tuple[list[bytes], list[tuple[ResponseCode, str]]]: + handler = DefaultRequestHandler() # No our_status set + response = MockResponseStream() + + peer_status = make_test_status() + await handler.handle_status(peer_status, response) + + return response.successes, response.errors + + successes, errors = run_async(run_test()) + + assert len(successes) == 0 + assert len(errors) == 1 + assert errors[0][0] == ResponseCode.SERVER_ERROR + assert "not available" in errors[0][1] + + def test_handle_status_ignores_peer_status(self) -> None: + """Peer's status does not affect our response.""" + + async def run_test() -> bytes: + our_status = make_test_status() + handler = DefaultRequestHandler(our_status=our_status) + response = MockResponseStream() + + # Peer claims different chain state + peer_status = Status( + finalized=Checkpoint(root=Bytes32(b"\xff" * 32), slot=Slot(9999)), + head=Checkpoint(root=Bytes32(b"\xee" * 32), slot=Slot(10000)), + ) + + await 
handler.handle_status(peer_status, response) + + return response.successes[0] + + ssz_data = run_async(run_test()) + + # Our response is independent of peer's status + returned_status = Status.decode_bytes(ssz_data) + assert returned_status.head.slot == Slot(200) + assert returned_status.finalized.slot == Slot(100) + + +# ----------------------------------------------------------------------------- +# TestDefaultRequestHandler - BlocksByRoot +# ----------------------------------------------------------------------------- + + +class TestDefaultRequestHandlerBlocksByRoot: + """Tests for DefaultRequestHandler.handle_blocks_by_root.""" + + def test_handle_blocks_by_root_returns_found_blocks(self) -> None: + """Sends SUCCESS response for each found block.""" + + async def run_test() -> tuple[list[bytes], list[tuple[ResponseCode, str]]]: + block1 = make_test_block(slot=1, seed=1) + block2 = make_test_block(slot=2, seed=2) + + # Create lookup that returns blocks for specific roots + block_roots: dict[bytes, SignedBlockWithAttestation] = { + b"\x11" * 32: block1, + b"\x22" * 32: block2, + } + + async def lookup(root: Bytes32) -> SignedBlockWithAttestation | None: + return block_roots.get(bytes(root)) + + handler = DefaultRequestHandler(block_lookup=lookup) + response = MockResponseStream() + + request = BlocksByRootRequest(data=[Bytes32(b"\x11" * 32), Bytes32(b"\x22" * 32)]) + + await handler.handle_blocks_by_root(request, response) + + return response.successes, response.errors + + successes, errors = run_async(run_test()) + + assert len(errors) == 0 + assert len(successes) == 2 + + # Both blocks should be decodable + decoded1 = SignedBlockWithAttestation.decode_bytes(successes[0]) + decoded2 = SignedBlockWithAttestation.decode_bytes(successes[1]) + + assert decoded1.message.block.slot == Slot(1) + assert decoded2.message.block.slot == Slot(2) + + def test_handle_blocks_by_root_skips_missing_blocks(self) -> None: + """Missing blocks are silently skipped.""" + + async def 
run_test() -> tuple[list[bytes], list[tuple[ResponseCode, str]]]: + block1 = make_test_block(slot=1, seed=1) + + # Only block1 exists + async def lookup(root: Bytes32) -> SignedBlockWithAttestation | None: + if bytes(root) == b"\x11" * 32: + return block1 + return None + + handler = DefaultRequestHandler(block_lookup=lookup) + response = MockResponseStream() + + # Request two blocks, only one exists + request = BlocksByRootRequest( + data=[ + Bytes32(b"\x11" * 32), # exists + Bytes32(b"\x99" * 32), # missing + ] + ) + + await handler.handle_blocks_by_root(request, response) + + return response.successes, response.errors + + successes, errors = run_async(run_test()) + + # Only one block returned, no errors + assert len(errors) == 0 + assert len(successes) == 1 + + def test_handle_blocks_by_root_no_lookup_returns_error(self) -> None: + """Returns SERVER_ERROR when no lookup callback is configured.""" + + async def run_test() -> tuple[list[bytes], list[tuple[ResponseCode, str]]]: + handler = DefaultRequestHandler() # No block_lookup set + response = MockResponseStream() + + request = BlocksByRootRequest(data=[Bytes32(b"\x11" * 32)]) + + await handler.handle_blocks_by_root(request, response) + + return response.successes, response.errors + + successes, errors = run_async(run_test()) + + assert len(successes) == 0 + assert len(errors) == 1 + assert errors[0][0] == ResponseCode.SERVER_ERROR + assert "not available" in errors[0][1] + + def test_handle_blocks_by_root_empty_request(self) -> None: + """Empty request returns no blocks and no errors.""" + + async def run_test() -> tuple[list[bytes], list[tuple[ResponseCode, str]]]: + async def lookup(root: Bytes32) -> SignedBlockWithAttestation | None: + return None + + handler = DefaultRequestHandler(block_lookup=lookup) + response = MockResponseStream() + + request = BlocksByRootRequest(data=[]) + + await handler.handle_blocks_by_root(request, response) + + return response.successes, response.errors + + successes, errors = 
run_async(run_test()) + + assert len(errors) == 0 + assert len(successes) == 0 + + def test_handle_blocks_by_root_lookup_error_continues(self) -> None: + """Lookup exceptions are caught and processing continues.""" + + async def run_test() -> tuple[list[bytes], list[tuple[ResponseCode, str]]]: + block2 = make_test_block(slot=2, seed=2) + + async def lookup(root: Bytes32) -> SignedBlockWithAttestation | None: + if bytes(root) == b"\x11" * 32: + raise RuntimeError("Database error") + if bytes(root) == b"\x22" * 32: + return block2 + return None + + handler = DefaultRequestHandler(block_lookup=lookup) + response = MockResponseStream() + + # First block causes error, second succeeds + request = BlocksByRootRequest(data=[Bytes32(b"\x11" * 32), Bytes32(b"\x22" * 32)]) + + await handler.handle_blocks_by_root(request, response) + + return response.successes, response.errors + + successes, errors = run_async(run_test()) + + # Second block still returned despite first lookup failing + assert len(errors) == 0 + assert len(successes) == 1 + + decoded = SignedBlockWithAttestation.decode_bytes(successes[0]) + assert decoded.message.block.slot == Slot(2) + + +# ----------------------------------------------------------------------------- +# TestReqRespServer +# ----------------------------------------------------------------------------- + + +class TestReqRespServer: + """Tests for ReqRespServer request handling.""" + + def test_handle_status_request(self) -> None: + """Full Status request/response flow through ReqRespServer.""" + + async def run_test() -> tuple[list[bytes], bool]: + our_status = make_test_status() + handler = DefaultRequestHandler(our_status=our_status) + server = ReqRespServer(handler=handler) + + # Build wire-format request + peer_status = Status( + finalized=Checkpoint(root=Bytes32(b"\xaa" * 32), slot=Slot(50)), + head=Checkpoint(root=Bytes32(b"\xbb" * 32), slot=Slot(150)), + ) + request_bytes = encode_request(peer_status.encode_bytes()) + + stream = 
MockStream(request_data=request_bytes) + + await server.handle_stream(stream, STATUS_PROTOCOL_V1) + + return stream.written, stream.closed + + written, closed = run_async(run_test()) + + # Stream should be closed after handling + assert closed is True + + # Should have received a success response + assert len(written) >= 1 + + # Decode the response + code, ssz_data = ResponseCode.decode(written[0]) + assert code == ResponseCode.SUCCESS + + returned_status = Status.decode_bytes(ssz_data) + assert returned_status.head.slot == Slot(200) + + def test_handle_blocks_by_root_request(self) -> None: + """Full BlocksByRoot request/response flow through ReqRespServer.""" + + async def run_test() -> tuple[list[bytes], bool]: + block1 = make_test_block(slot=1, seed=1) + root1 = Bytes32(b"\x11" * 32) + + async def lookup(root: Bytes32) -> SignedBlockWithAttestation | None: + if bytes(root) == bytes(root1): + return block1 + return None + + handler = DefaultRequestHandler(block_lookup=lookup) + server = ReqRespServer(handler=handler) + + # Build wire-format request + request = BlocksByRootRequest(data=[root1]) + request_bytes = encode_request(request.encode_bytes()) + + stream = MockStream(request_data=request_bytes) + + await server.handle_stream(stream, BLOCKS_BY_ROOT_PROTOCOL_V1) + + return stream.written, stream.closed + + written, closed = run_async(run_test()) + + assert closed is True + assert len(written) >= 1 + + code, ssz_data = ResponseCode.decode(written[0]) + assert code == ResponseCode.SUCCESS + + returned_block = SignedBlockWithAttestation.decode_bytes(ssz_data) + assert returned_block.message.block.slot == Slot(1) + + def test_empty_request_returns_error(self) -> None: + """Empty request data returns INVALID_REQUEST error.""" + + async def run_test() -> tuple[list[bytes], bool]: + handler = DefaultRequestHandler(our_status=make_test_status()) + server = ReqRespServer(handler=handler) + + stream = MockStream(request_data=b"") + + await server.handle_stream(stream, 
STATUS_PROTOCOL_V1) + + return stream.written, stream.closed + + written, closed = run_async(run_test()) + + assert closed is True + assert len(written) >= 1 + + code, message = ResponseCode.decode(written[0]) + assert code == ResponseCode.INVALID_REQUEST + assert b"Empty" in message + + def test_decode_error_returns_invalid_request(self) -> None: + """Malformed wire data returns INVALID_REQUEST error.""" + + async def run_test() -> tuple[list[bytes], bool]: + handler = DefaultRequestHandler(our_status=make_test_status()) + server = ReqRespServer(handler=handler) + + # Invalid snappy data after length prefix + malformed_data = b"\x10\x00\x00\x00invalid snappy data here" + stream = MockStream(request_data=malformed_data) + + await server.handle_stream(stream, STATUS_PROTOCOL_V1) + + return stream.written, stream.closed + + written, closed = run_async(run_test()) + + assert closed is True + assert len(written) >= 1 + + code, _ = ResponseCode.decode(written[0]) + assert code == ResponseCode.INVALID_REQUEST + + def test_invalid_ssz_returns_invalid_request(self) -> None: + """Valid wire format but invalid SSZ returns INVALID_REQUEST.""" + + async def run_test() -> tuple[list[bytes], bool]: + handler = DefaultRequestHandler(our_status=make_test_status()) + server = ReqRespServer(handler=handler) + + # Valid wire format but SSZ is too short for Status (needs 80 bytes) + invalid_ssz = b"\x01\x02\x03\x04" + request_bytes = encode_request(invalid_ssz) + stream = MockStream(request_data=request_bytes) + + await server.handle_stream(stream, STATUS_PROTOCOL_V1) + + return stream.written, stream.closed + + written, closed = run_async(run_test()) + + assert closed is True + assert len(written) >= 1 + + code, message = ResponseCode.decode(written[0]) + assert code == ResponseCode.INVALID_REQUEST + assert b"Invalid Status" in message or b"Status" in message + + def test_unknown_protocol_returns_error(self) -> None: + """Unknown protocol ID returns SERVER_ERROR.""" + + async def 
run_test() -> tuple[list[bytes], bool]: + handler = DefaultRequestHandler(our_status=make_test_status()) + server = ReqRespServer(handler=handler) + + # Valid request data but unknown protocol + status = make_test_status() + request_bytes = encode_request(status.encode_bytes()) + stream = MockStream(request_data=request_bytes) + + unknown_protocol = "/unknown/protocol/1/ssz_snappy" + await server.handle_stream(stream, unknown_protocol) + + return stream.written, stream.closed + + written, closed = run_async(run_test()) + + assert closed is True + assert len(written) >= 1 + + code, message = ResponseCode.decode(written[0]) + assert code == ResponseCode.SERVER_ERROR + assert b"Unknown" in message or b"protocol" in message.lower() + + def test_stream_closed_on_completion(self) -> None: + """Stream is always closed after handling, even on success.""" + + async def run_test() -> bool: + handler = DefaultRequestHandler(our_status=make_test_status()) + server = ReqRespServer(handler=handler) + + status = make_test_status() + request_bytes = encode_request(status.encode_bytes()) + stream = MockStream(request_data=request_bytes) + + await server.handle_stream(stream, STATUS_PROTOCOL_V1) + + return stream.closed + + closed = run_async(run_test()) + assert closed is True + + def test_stream_closed_on_error(self) -> None: + """Stream is closed even when handling fails.""" + + async def run_test() -> bool: + handler = DefaultRequestHandler() # No status configured + server = ReqRespServer(handler=handler) + + status = make_test_status() + request_bytes = encode_request(status.encode_bytes()) + stream = MockStream(request_data=request_bytes) + + await server.handle_stream(stream, STATUS_PROTOCOL_V1) + + return stream.closed + + closed = run_async(run_test()) + assert closed is True + + +# ----------------------------------------------------------------------------- +# TestReqRespProtocolConstants +# ----------------------------------------------------------------------------- + 
+ +class TestReqRespProtocolConstants: + """Tests for protocol ID constants.""" + + def test_protocol_ids_contains_status(self) -> None: + """REQRESP_PROTOCOL_IDS includes status protocol.""" + assert STATUS_PROTOCOL_V1 in REQRESP_PROTOCOL_IDS + + def test_protocol_ids_contains_blocks_by_root(self) -> None: + """REQRESP_PROTOCOL_IDS includes blocks_by_root protocol.""" + assert BLOCKS_BY_ROOT_PROTOCOL_V1 in REQRESP_PROTOCOL_IDS + + def test_protocol_ids_is_frozenset(self) -> None: + """REQRESP_PROTOCOL_IDS is immutable.""" + assert isinstance(REQRESP_PROTOCOL_IDS, frozenset) + + def test_status_protocol_format(self) -> None: + """Status protocol ID follows expected format.""" + assert STATUS_PROTOCOL_V1.startswith("/leanconsensus/req/") + assert STATUS_PROTOCOL_V1.endswith("/ssz_snappy") + assert "status" in STATUS_PROTOCOL_V1 + + def test_blocks_by_root_protocol_format(self) -> None: + """BlocksByRoot protocol ID follows expected format.""" + assert BLOCKS_BY_ROOT_PROTOCOL_V1.startswith("/leanconsensus/req/") + assert BLOCKS_BY_ROOT_PROTOCOL_V1.endswith("/ssz_snappy") + assert "blocks_by_root" in BLOCKS_BY_ROOT_PROTOCOL_V1 + + +# ----------------------------------------------------------------------------- +# TestIntegration - Roundtrip Tests +# ----------------------------------------------------------------------------- + + +class TestIntegration: + """Integration tests for full request/response roundtrips.""" + + def test_roundtrip_status_request(self) -> None: + """Full encode -> server -> decode roundtrip for Status.""" + + async def run_test() -> Status: + our_status = make_test_status() + handler = DefaultRequestHandler(our_status=our_status) + server = ReqRespServer(handler=handler) + + # Client side: encode request + peer_status = Status( + finalized=Checkpoint(root=Bytes32(b"\xcc" * 32), slot=Slot(300)), + head=Checkpoint(root=Bytes32(b"\xdd" * 32), slot=Slot(400)), + ) + request_wire = encode_request(peer_status.encode_bytes()) + + # Server side: handle 
request + stream = MockStream(request_data=request_wire) + await server.handle_stream(stream, STATUS_PROTOCOL_V1) + + # Client side: decode response + response_wire = stream.written[0] + code, ssz_bytes = ResponseCode.decode(response_wire) + + assert code == ResponseCode.SUCCESS + return Status.decode_bytes(ssz_bytes) + + returned = run_async(run_test()) + + # Verify we got our status back + assert returned.head.slot == Slot(200) + assert returned.finalized.slot == Slot(100) + + def test_roundtrip_blocks_by_root_request(self) -> None: + """Full encode -> server -> decode roundtrip for BlocksByRoot.""" + + async def run_test() -> list[SignedBlockWithAttestation]: + block1 = make_test_block(slot=10, seed=10) + block2 = make_test_block(slot=20, seed=20) + + root1 = Bytes32(b"\xaa" * 32) + root2 = Bytes32(b"\xbb" * 32) + + blocks_by_root: dict[bytes, SignedBlockWithAttestation] = { + bytes(root1): block1, + bytes(root2): block2, + } + + async def lookup(root: Bytes32) -> SignedBlockWithAttestation | None: + return blocks_by_root.get(bytes(root)) + + handler = DefaultRequestHandler(block_lookup=lookup) + server = ReqRespServer(handler=handler) + + # Client side: encode request + request = BlocksByRootRequest(data=[root1, root2]) + request_wire = encode_request(request.encode_bytes()) + + # Server side: handle request + stream = MockStream(request_data=request_wire) + await server.handle_stream(stream, BLOCKS_BY_ROOT_PROTOCOL_V1) + + # Client side: decode responses + results = [] + for response_wire in stream.written: + code, ssz_bytes = ResponseCode.decode(response_wire) + if code == ResponseCode.SUCCESS: + results.append(SignedBlockWithAttestation.decode_bytes(ssz_bytes)) + + return results + + blocks = run_async(run_test()) + + assert len(blocks) == 2 + slots = {b.message.block.slot for b in blocks} + assert Slot(10) in slots + assert Slot(20) in slots + + def test_roundtrip_blocks_by_root_partial_response(self) -> None: + """BlocksByRoot returns only available 
blocks.""" + + async def run_test() -> list[SignedBlockWithAttestation]: + block1 = make_test_block(slot=10, seed=10) + + root1 = Bytes32(b"\xaa" * 32) + root_missing = Bytes32(b"\x00" * 32) + + async def lookup(root: Bytes32) -> SignedBlockWithAttestation | None: + if bytes(root) == bytes(root1): + return block1 + return None + + handler = DefaultRequestHandler(block_lookup=lookup) + server = ReqRespServer(handler=handler) + + # Request two blocks, only one exists + request = BlocksByRootRequest(data=[root1, root_missing]) + request_wire = encode_request(request.encode_bytes()) + + stream = MockStream(request_data=request_wire) + await server.handle_stream(stream, BLOCKS_BY_ROOT_PROTOCOL_V1) + + results = [] + for response_wire in stream.written: + code, ssz_bytes = ResponseCode.decode(response_wire) + if code == ResponseCode.SUCCESS: + results.append(SignedBlockWithAttestation.decode_bytes(ssz_bytes)) + + return results + + blocks = run_async(run_test()) + + # Only one block returned + assert len(blocks) == 1 + assert blocks[0].message.block.slot == Slot(10) + + +# ----------------------------------------------------------------------------- +# TestResponseStreamProtocol +# ----------------------------------------------------------------------------- + + +class TestResponseStreamProtocol: + """Tests verifying ResponseStream protocol compliance.""" + + def test_mock_response_stream_is_protocol_compliant(self) -> None: + """MockResponseStream implements ResponseStream protocol.""" + # This test verifies our mock is usable with the handler + mock = MockResponseStream() + + # Should have the required methods + assert hasattr(mock, "send_success") + assert hasattr(mock, "send_error") + assert hasattr(mock, "finish") + + # Methods should be callable + assert callable(mock.send_success) + assert callable(mock.send_error) + assert callable(mock.finish) + + def test_yamux_response_stream_is_protocol_compliant(self) -> None: + """YamuxResponseStream implements 
ResponseStream protocol.""" + stream = MockStream() + yamux = YamuxResponseStream(_stream=stream) + + assert hasattr(yamux, "send_success") + assert hasattr(yamux, "send_error") + assert hasattr(yamux, "finish") + + +# ----------------------------------------------------------------------------- +# TestBlockLookupTypeAlias +# ----------------------------------------------------------------------------- + + +class TestBlockLookupTypeAlias: + """Tests for BlockLookup type alias usage.""" + + def test_async_function_matches_block_lookup_signature(self) -> None: + """Verify async function can be used as BlockLookup.""" + + async def my_lookup(root: Bytes32) -> SignedBlockWithAttestation | None: + return None + + # Should type-check as BlockLookup + lookup: BlockLookup = my_lookup + + async def run_test() -> SignedBlockWithAttestation | None: + return await lookup(Bytes32(b"\x00" * 32)) + + result = run_async(run_test()) + assert result is None + + def test_block_lookup_returning_block(self) -> None: + """BlockLookup returning a block works correctly.""" + block = make_test_block(slot=42, seed=42) + + async def my_lookup(root: Bytes32) -> SignedBlockWithAttestation | None: + return block + + async def run_test() -> SignedBlockWithAttestation | None: + return await my_lookup(Bytes32(b"\x00" * 32)) + + result = run_async(run_test()) + assert result is not None + assert result.message.block.slot == Slot(42) + + +# ----------------------------------------------------------------------------- +# TestYamuxResponseStreamMultipleResponses +# ----------------------------------------------------------------------------- + + +class TestYamuxResponseStreamMultipleResponses: + """Tests for YamuxResponseStream with multiple responses in sequence.""" + + def test_send_multiple_success_responses(self) -> None: + """Multiple SUCCESS responses are written independently.""" + + async def run_test() -> list[bytes]: + stream = MockStream() + response = YamuxResponseStream(_stream=stream) + 
+ await response.send_success(b"\x01\x02") + await response.send_success(b"\x03\x04") + await response.send_success(b"\x05\x06") + + return stream.written + + written = run_async(run_test()) + + assert len(written) == 3 + + # Each response should be independently decodable + for i, data in enumerate(written): + code, decoded = ResponseCode.decode(data) + assert code == ResponseCode.SUCCESS + expected = bytes([i * 2 + 1, i * 2 + 2]) + assert decoded == expected + + def test_send_success_then_error(self) -> None: + """Success response followed by error response.""" + + async def run_test() -> list[bytes]: + stream = MockStream() + response = YamuxResponseStream(_stream=stream) + + await response.send_success(b"\xaa\xbb") + await response.send_error(ResponseCode.RESOURCE_UNAVAILABLE, "Done") + + return stream.written + + written = run_async(run_test()) + + assert len(written) == 2 + + code1, data1 = ResponseCode.decode(written[0]) + assert code1 == ResponseCode.SUCCESS + assert data1 == b"\xaa\xbb" + + code2, data2 = ResponseCode.decode(written[1]) + assert code2 == ResponseCode.RESOURCE_UNAVAILABLE + assert data2 == b"Done" + + def test_send_empty_success_response(self) -> None: + """Empty SUCCESS response payload is handled.""" + + async def run_test() -> list[bytes]: + stream = MockStream() + response = YamuxResponseStream(_stream=stream) + + await response.send_success(b"") + + return stream.written + + written = run_async(run_test()) + + assert len(written) == 1 + code, decoded = ResponseCode.decode(written[0]) + assert code == ResponseCode.SUCCESS + assert decoded == b"" + + +# ----------------------------------------------------------------------------- +# TestMockStreamChunkedRead +# ----------------------------------------------------------------------------- + + +class MockChunkedStream: + """Mock stream that returns data in multiple chunks.""" + + def __init__(self, chunks: list[bytes]) -> None: + """Initialize with a list of chunks to return.""" + 
self.chunks = chunks + self.chunk_index = 0 + self.written: list[bytes] = [] + self.closed = False + + @property + def protocol_id(self) -> str: + """Mock protocol ID.""" + return STATUS_PROTOCOL_V1 + + async def read(self, n: int = -1) -> bytes: + """Return chunks one at a time.""" + if self.chunk_index >= len(self.chunks): + return b"" + chunk = self.chunks[self.chunk_index] + self.chunk_index += 1 + return chunk + + async def write(self, data: bytes) -> None: + """Accumulate written data.""" + self.written.append(data) + + async def close(self) -> None: + """Mark stream as closed.""" + self.closed = True + + async def reset(self) -> None: + """Abort the stream.""" + self.closed = True + + +class TestReqRespServerChunkedRead: + """Tests for ReqRespServer handling chunked request data.""" + + def test_handle_chunked_status_request(self) -> None: + """Request data arriving in multiple chunks is assembled correctly.""" + + async def run_test() -> tuple[list[bytes], bool]: + our_status = make_test_status() + handler = DefaultRequestHandler(our_status=our_status) + server = ReqRespServer(handler=handler) + + # Build wire-format request + peer_status = make_test_status() + request_bytes = encode_request(peer_status.encode_bytes()) + + # Split into multiple chunks + mid = len(request_bytes) // 2 + chunks = [request_bytes[:mid], request_bytes[mid:]] + + stream = MockChunkedStream(chunks=chunks) + + await server.handle_stream(stream, STATUS_PROTOCOL_V1) + + return stream.written, stream.closed + + written, closed = run_async(run_test()) + + assert closed is True + assert len(written) >= 1 + + code, ssz_data = ResponseCode.decode(written[0]) + assert code == ResponseCode.SUCCESS + + def test_handle_single_byte_chunks(self) -> None: + """Request data arriving one byte at a time is handled.""" + + async def run_test() -> tuple[list[bytes], bool]: + our_status = make_test_status() + handler = DefaultRequestHandler(our_status=our_status) + server = 
ReqRespServer(handler=handler) + + peer_status = make_test_status() + request_bytes = encode_request(peer_status.encode_bytes()) + + # Split into single-byte chunks + chunks = [bytes([b]) for b in request_bytes] + + stream = MockChunkedStream(chunks=chunks) + + await server.handle_stream(stream, STATUS_PROTOCOL_V1) + + return stream.written, stream.closed + + written, closed = run_async(run_test()) + + assert closed is True + assert len(written) >= 1 + + code, _ = ResponseCode.decode(written[0]) + assert code == ResponseCode.SUCCESS + + +# ----------------------------------------------------------------------------- +# TestReqRespServerEdgeCases +# ----------------------------------------------------------------------------- + + +class TestReqRespServerEdgeCases: + """Edge cases for ReqRespServer.""" + + def test_invalid_blocks_by_root_ssz(self) -> None: + """Invalid SSZ for BlocksByRoot returns INVALID_REQUEST.""" + + async def run_test() -> tuple[list[bytes], bool]: + async def lookup(root: Bytes32) -> SignedBlockWithAttestation | None: + return None + + handler = DefaultRequestHandler(block_lookup=lookup) + server = ReqRespServer(handler=handler) + + # Valid wire format but wrong SSZ structure for BlocksByRootRequest + # BlocksByRootRequest expects list of Bytes32, not arbitrary bytes + invalid_ssz = b"\xff" * 10 + request_bytes = encode_request(invalid_ssz) + stream = MockStream(request_data=request_bytes) + + await server.handle_stream(stream, BLOCKS_BY_ROOT_PROTOCOL_V1) + + return stream.written, stream.closed + + written, closed = run_async(run_test()) + + assert closed is True + assert len(written) >= 1 + + code, message = ResponseCode.decode(written[0]) + assert code == ResponseCode.INVALID_REQUEST + + def test_truncated_varint_returns_error(self) -> None: + """Truncated varint in request returns INVALID_REQUEST.""" + + async def run_test() -> tuple[list[bytes], bool]: + handler = DefaultRequestHandler(our_status=make_test_status()) + server = 
ReqRespServer(handler=handler) + + # Varint with continuation bit set but no following byte + truncated_varint = b"\x80" + stream = MockStream(request_data=truncated_varint) + + await server.handle_stream(stream, STATUS_PROTOCOL_V1) + + return stream.written, stream.closed + + written, closed = run_async(run_test()) + + assert closed is True + assert len(written) >= 1 + + code, _ = ResponseCode.decode(written[0]) + assert code == ResponseCode.INVALID_REQUEST + + +# ----------------------------------------------------------------------------- +# TestDefaultRequestHandlerEdgeCases +# ----------------------------------------------------------------------------- + + +class TestDefaultRequestHandlerEdgeCases: + """Edge cases for DefaultRequestHandler.""" + + def test_blocks_by_root_single_block(self) -> None: + """Single block request returns correctly.""" + + async def run_test() -> tuple[list[bytes], list[tuple[ResponseCode, str]]]: + block = make_test_block(slot=999, seed=99) + root = Bytes32(b"\x99" * 32) + + async def lookup(r: Bytes32) -> SignedBlockWithAttestation | None: + if bytes(r) == bytes(root): + return block + return None + + handler = DefaultRequestHandler(block_lookup=lookup) + response = MockResponseStream() + + request = BlocksByRootRequest(data=[root]) + + await handler.handle_blocks_by_root(request, response) + + return response.successes, response.errors + + successes, errors = run_async(run_test()) + + assert len(errors) == 0 + assert len(successes) == 1 + + decoded = SignedBlockWithAttestation.decode_bytes(successes[0]) + assert decoded.message.block.slot == Slot(999) + + def test_blocks_by_root_all_missing(self) -> None: + """Request where all blocks are missing returns no success responses.""" + + async def run_test() -> tuple[list[bytes], list[tuple[ResponseCode, str]]]: + async def lookup(root: Bytes32) -> SignedBlockWithAttestation | None: + return None + + handler = DefaultRequestHandler(block_lookup=lookup) + response = 
MockResponseStream() + + request = BlocksByRootRequest( + data=[ + Bytes32(b"\x11" * 32), + Bytes32(b"\x22" * 32), + Bytes32(b"\x33" * 32), + ] + ) + + await handler.handle_blocks_by_root(request, response) + + return response.successes, response.errors + + successes, errors = run_async(run_test()) + + assert len(errors) == 0 + assert len(successes) == 0 + + def test_blocks_by_root_mixed_found_missing(self) -> None: + """Mixed found/missing blocks returns only found blocks.""" + + async def run_test() -> tuple[list[bytes], list[tuple[ResponseCode, str]]]: + block1 = make_test_block(slot=1, seed=1) + block3 = make_test_block(slot=3, seed=3) + + blocks: dict[bytes, SignedBlockWithAttestation] = { + b"\x11" * 32: block1, + # \x22 missing + b"\x33" * 32: block3, + } + + async def lookup(root: Bytes32) -> SignedBlockWithAttestation | None: + return blocks.get(bytes(root)) + + handler = DefaultRequestHandler(block_lookup=lookup) + response = MockResponseStream() + + request = BlocksByRootRequest( + data=[ + Bytes32(b"\x11" * 32), + Bytes32(b"\x22" * 32), + Bytes32(b"\x33" * 32), + ] + ) + + await handler.handle_blocks_by_root(request, response) + + return response.successes, response.errors + + successes, errors = run_async(run_test()) + + assert len(errors) == 0 + assert len(successes) == 2 + + # Verify order is preserved + decoded1 = SignedBlockWithAttestation.decode_bytes(successes[0]) + decoded2 = SignedBlockWithAttestation.decode_bytes(successes[1]) + + assert decoded1.message.block.slot == Slot(1) + assert decoded2.message.block.slot == Slot(3) + + def test_status_update_after_initialization(self) -> None: + """Status can be updated after handler creation.""" + + async def run_test() -> tuple[list[bytes], list[bytes]]: + handler = DefaultRequestHandler() + response1 = MockResponseStream() + + # First request with no status + await handler.handle_status(make_test_status(), response1) + + # Update status + handler.our_status = make_test_status() + + response2 = 
MockResponseStream() + await handler.handle_status(make_test_status(), response2) + + return response1.successes, response2.successes + + successes1, successes2 = run_async(run_test()) + + # First request should fail + assert len(successes1) == 0 + + # Second request should succeed + assert len(successes2) == 1 + + +# ----------------------------------------------------------------------------- +# TestConcurrentRequestHandling +# ----------------------------------------------------------------------------- + + +class TestConcurrentRequestHandling: + """Tests for concurrent request handling.""" + + def test_concurrent_status_requests(self) -> None: + """Multiple concurrent status requests are handled independently.""" + + async def run_test() -> list[Status]: + our_status = make_test_status() + handler = DefaultRequestHandler(our_status=our_status) + server = ReqRespServer(handler=handler) + + # Create multiple streams with requests + streams: list[MockStream] = [] + for i in range(3): + peer_status = Status( + finalized=Checkpoint(root=Bytes32(bytes([i]) * 32), slot=Slot(i * 10)), + head=Checkpoint(root=Bytes32(bytes([i + 10]) * 32), slot=Slot(i * 20)), + ) + request_bytes = encode_request(peer_status.encode_bytes()) + streams.append(MockStream(request_data=request_bytes)) + + # Handle all requests concurrently + await asyncio.gather(*[server.handle_stream(s, STATUS_PROTOCOL_V1) for s in streams]) + + # Decode all responses + results = [] + for stream in streams: + assert stream.closed + assert len(stream.written) >= 1 + code, ssz_data = ResponseCode.decode(stream.written[0]) + assert code == ResponseCode.SUCCESS + results.append(Status.decode_bytes(ssz_data)) + + return results + + results = run_async(run_test()) + + # All responses should be our status + for status in results: + assert status.head.slot == Slot(200) + assert status.finalized.slot == Slot(100) + + def test_concurrent_mixed_requests(self) -> None: + """Concurrent Status and BlocksByRoot requests.""" 
+ + async def run_test() -> tuple[list[Status], list[SignedBlockWithAttestation]]: + block = make_test_block(slot=42, seed=42) + root = Bytes32(b"\x42" * 32) + + async def lookup(r: Bytes32) -> SignedBlockWithAttestation | None: + if bytes(r) == bytes(root): + return block + return None + + our_status = make_test_status() + handler = DefaultRequestHandler(our_status=our_status, block_lookup=lookup) + server = ReqRespServer(handler=handler) + + # Status request + status_request = encode_request(make_test_status().encode_bytes()) + status_stream = MockStream(request_data=status_request) + + # BlocksByRoot request + blocks_request = encode_request(BlocksByRootRequest(data=[root]).encode_bytes()) + blocks_stream = MockStream(request_data=blocks_request) + + # Handle concurrently + await asyncio.gather( + server.handle_stream(status_stream, STATUS_PROTOCOL_V1), + server.handle_stream(blocks_stream, BLOCKS_BY_ROOT_PROTOCOL_V1), + ) + + # Decode status response + code, ssz_data = ResponseCode.decode(status_stream.written[0]) + assert code == ResponseCode.SUCCESS + status_result = Status.decode_bytes(ssz_data) + + # Decode block response + code, ssz_data = ResponseCode.decode(blocks_stream.written[0]) + assert code == ResponseCode.SUCCESS + block_result = SignedBlockWithAttestation.decode_bytes(ssz_data) + + return [status_result], [block_result] + + statuses, blocks = run_async(run_test()) + + assert len(statuses) == 1 + assert statuses[0].head.slot == Slot(200) + + assert len(blocks) == 1 + assert blocks[0].message.block.slot == Slot(42) + + +# ----------------------------------------------------------------------------- +# TestHandlerExceptionRecovery +# ----------------------------------------------------------------------------- + + +class MockFailingStream: + """Mock stream that raises exceptions on specific operations.""" + + def __init__( + self, + request_data: bytes = b"", + fail_on_write: bool = False, + fail_on_close: bool = False, + ) -> None: + """Initialize 
with failure modes.""" + self.request_data = request_data + self.fail_on_write = fail_on_write + self.fail_on_close = fail_on_close + self._read_offset = 0 + self.written: list[bytes] = [] + self.closed = False + self.close_attempts = 0 + + @property + def protocol_id(self) -> str: + """Mock protocol ID.""" + return STATUS_PROTOCOL_V1 + + async def read(self, n: int = -1) -> bytes: + """Return request data.""" + if self._read_offset >= len(self.request_data): + return b"" + chunk = self.request_data[self._read_offset :] + self._read_offset = len(self.request_data) + return chunk + + async def write(self, data: bytes) -> None: + """Optionally fail on write.""" + if self.fail_on_write: + raise ConnectionError("Write failed") + self.written.append(data) + + async def close(self) -> None: + """Optionally fail on close.""" + self.close_attempts += 1 + if self.fail_on_close: + raise ConnectionError("Close failed") + self.closed = True + + async def reset(self) -> None: + """Abort the stream.""" + self.closed = True + + +class TestHandlerExceptionRecovery: + """Tests for exception handling and recovery.""" + + def test_stream_closed_despite_close_exception(self) -> None: + """Stream close is attempted even if it raises an exception.""" + + async def run_test() -> int: + handler = DefaultRequestHandler(our_status=make_test_status()) + server = ReqRespServer(handler=handler) + + request_bytes = encode_request(make_test_status().encode_bytes()) + stream = MockFailingStream( + request_data=request_bytes, + fail_on_close=True, + ) + + # Should not raise, exception is caught + await server.handle_stream(stream, STATUS_PROTOCOL_V1) + + return stream.close_attempts + + close_attempts = run_async(run_test()) + + # Close should be attempted + assert close_attempts >= 1 + + def test_error_response_sent_despite_write_exception(self) -> None: + """Error handling continues even when write fails.""" + + async def run_test() -> int: + handler = DefaultRequestHandler() # No status + 
server = ReqRespServer(handler=handler) + + request_bytes = encode_request(make_test_status().encode_bytes()) + stream = MockFailingStream( + request_data=request_bytes, + fail_on_write=True, + ) + + # Should not raise, writes that fail are caught + await server.handle_stream(stream, STATUS_PROTOCOL_V1) + + return stream.close_attempts + + close_attempts = run_async(run_test()) + + # Close should still be attempted after write failure + assert close_attempts >= 1 + + +# ----------------------------------------------------------------------------- +# TestRequestTimeoutConstant +# ----------------------------------------------------------------------------- + + +class TestRequestTimeoutConstant: + """Tests for REQUEST_TIMEOUT_SECONDS constant.""" + + def test_timeout_is_positive(self) -> None: + """Request timeout is a positive number.""" + from lean_spec.subspecs.networking.reqresp.handler import REQUEST_TIMEOUT_SECONDS + + assert REQUEST_TIMEOUT_SECONDS > 0 + + def test_timeout_is_reasonable(self) -> None: + """Request timeout is within reasonable bounds.""" + from lean_spec.subspecs.networking.reqresp.handler import REQUEST_TIMEOUT_SECONDS + + # Should be at least one second + assert REQUEST_TIMEOUT_SECONDS >= 1.0 + # Should not be excessively long + assert REQUEST_TIMEOUT_SECONDS <= 60.0 + + +# ----------------------------------------------------------------------------- +# TestMockStreamProtocolCompliance +# ----------------------------------------------------------------------------- + + +class TestMockStreamProtocolCompliance: + """Tests verifying mock streams match the Stream protocol.""" + + def test_mock_stream_has_protocol_id(self) -> None: + """MockStream has protocol_id property.""" + stream = MockStream() + assert hasattr(stream, "protocol_id") + assert isinstance(stream.protocol_id, str) + + def test_mock_stream_has_read_method(self) -> None: + """MockStream has read method.""" + stream = MockStream() + assert hasattr(stream, "read") + assert 
callable(stream.read) + + def test_mock_stream_has_write_method(self) -> None: + """MockStream has write method.""" + stream = MockStream() + assert hasattr(stream, "write") + assert callable(stream.write) + + def test_mock_stream_has_close_method(self) -> None: + """MockStream has close method.""" + stream = MockStream() + assert hasattr(stream, "close") + assert callable(stream.close) + + def test_mock_stream_has_reset_method(self) -> None: + """MockStream has reset method.""" + stream = MockStream() + assert hasattr(stream, "reset") + assert callable(stream.reset) + + def test_mock_stream_reset_closes_stream(self) -> None: + """MockStream reset marks stream as closed.""" + + async def run_test() -> bool: + stream = MockStream() + await stream.reset() + return stream.closed + + closed = run_async(run_test()) + assert closed is True diff --git a/tests/lean_spec/subspecs/validator/test_service.py b/tests/lean_spec/subspecs/validator/test_service.py index e545e59d..c935bf5a 100644 --- a/tests/lean_spec/subspecs/validator/test_service.py +++ b/tests/lean_spec/subspecs/validator/test_service.py @@ -23,6 +23,7 @@ from lean_spec.subspecs.validator import ValidatorRegistry, ValidatorService from lean_spec.subspecs.validator.registry import ValidatorEntry from lean_spec.types import Uint64 +from lean_spec.types.validator import is_proposer class MockNetworkRequester(NetworkRequester): @@ -274,3 +275,186 @@ async def check_sleep() -> None: expected = float(genesis) - current_time # 100 seconds assert captured_duration is not None assert abs(captured_duration - expected) < 0.001 + + +class TestProposerSkipping: + """Tests for proposer skipping during attestation production.""" + + def test_proposer_skipped_in_attestation_production( + self, + sync_service: SyncService, + ) -> None: + """Proposer is skipped when producing attestations at interval 1. + + At slot 0, validator 0 is the proposer (0 % 3 == 0). 
+ When controlling validators 0 and 1, only validator 1 should produce an attestation + since validator 0 already attested within their block. + """ + clock = SlotClock(genesis_time=Uint64(0)) + + # Registry with validators 0 and 1. + registry = ValidatorRegistry() + for i in [0, 1]: + mock_key = MagicMock() + registry.add(ValidatorEntry(index=Uint64(i), secret_key=mock_key)) + + # Track which validators had _sign_attestation called. + signed_validator_ids: list[Uint64] = [] + + def mock_sign_attestation( + self: ValidatorService, # noqa: ARG001 + attestation_data: object, # noqa: ARG001 + validator_index: Uint64, + ) -> SignedAttestation: + signed_validator_ids.append(validator_index) + return MagicMock(spec=SignedAttestation, validator_id=validator_index) + + service = ValidatorService( + sync_service=sync_service, + clock=clock, + registry=registry, + ) + + async def produce() -> None: + # Slot 0: validator 0 is proposer (0 % 3 == 0). + with patch.object( + ValidatorService, + "_sign_attestation", + mock_sign_attestation, + ): + await service._produce_attestations(Slot(0)) + + asyncio.run(produce()) + + # Only validator 1 should have signed an attestation. + assert len(signed_validator_ids) == 1 + assert signed_validator_ids[0] == Uint64(1) + assert service.attestations_produced == 1 + + def test_non_proposer_still_attests( + self, + sync_service: SyncService, + ) -> None: + """Non-proposer validators still produce attestations. + + At slot 1, validator 1 is the proposer (1 % 3 == 1). + Validator 0 is not the proposer so should produce an attestation. + """ + clock = SlotClock(genesis_time=Uint64(0)) + + # Registry with only validator 0. + registry = ValidatorRegistry() + mock_key = MagicMock() + registry.add(ValidatorEntry(index=Uint64(0), secret_key=mock_key)) + + # Track which validators had _sign_attestation called. 
+ signed_validator_ids: list[Uint64] = [] + + def mock_sign_attestation( + self: ValidatorService, # noqa: ARG001 + attestation_data: object, # noqa: ARG001 + validator_index: Uint64, + ) -> SignedAttestation: + signed_validator_ids.append(validator_index) + return MagicMock(spec=SignedAttestation, validator_id=validator_index) + + service = ValidatorService( + sync_service=sync_service, + clock=clock, + registry=registry, + ) + + async def produce() -> None: + # Slot 1: validator 1 is proposer (1 % 3 == 1). + # Validator 0 is not proposer, should attest. + with patch.object( + ValidatorService, + "_sign_attestation", + mock_sign_attestation, + ): + await service._produce_attestations(Slot(1)) + + asyncio.run(produce()) + + # Validator 0 should have signed an attestation. + assert len(signed_validator_ids) == 1 + assert signed_validator_ids[0] == Uint64(0) + assert service.attestations_produced == 1 + + def test_multiple_validators_only_non_proposers_attest( + self, + sync_service: SyncService, + ) -> None: + """With multiple validators, only non-proposers produce attestations. + + At slot 2, validator 2 is the proposer (2 % 3 == 2). + Controlling validators 0, 1, and 2, only validators 0 and 1 should attest. + """ + clock = SlotClock(genesis_time=Uint64(0)) + + # Registry with validators 0, 1, and 2. + registry = ValidatorRegistry() + for i in [0, 1, 2]: + mock_key = MagicMock() + registry.add(ValidatorEntry(index=Uint64(i), secret_key=mock_key)) + + # Track which validators had _sign_attestation called. 
+ signed_validator_ids: list[Uint64] = [] + + def mock_sign_attestation( + self: ValidatorService, # noqa: ARG001 + attestation_data: object, # noqa: ARG001 + validator_index: Uint64, + ) -> SignedAttestation: + signed_validator_ids.append(validator_index) + return MagicMock(spec=SignedAttestation, validator_id=validator_index) + + service = ValidatorService( + sync_service=sync_service, + clock=clock, + registry=registry, + ) + + async def produce() -> None: + # Slot 2: validator 2 is proposer (2 % 3 == 2). + with patch.object( + ValidatorService, + "_sign_attestation", + mock_sign_attestation, + ): + await service._produce_attestations(Slot(2)) + + asyncio.run(produce()) + + # Validators 0 and 1 should have signed attestations. + assert len(signed_validator_ids) == 2 + assert set(signed_validator_ids) == {Uint64(0), Uint64(1)} + assert service.attestations_produced == 2 + + # Verify validator 2 (proposer) did not sign. + assert Uint64(2) not in signed_validator_ids + + def test_is_proposer_consistency_with_skip_logic( + self, + genesis_state: State, + ) -> None: + """The is_proposer function correctly identifies proposers. + + Verifies the proposer selection logic used by the skip mechanism. + """ + num_validators = Uint64(len(genesis_state.validators)) + + # At each slot, exactly one validator is the proposer. + for slot in range(6): + proposer_count = sum( + 1 + for i in range(int(num_validators)) + if is_proposer(Uint64(i), Uint64(slot), num_validators) + ) + assert proposer_count == 1 + + # Verify round-robin pattern. 
+ assert is_proposer(Uint64(0), Uint64(0), num_validators) + assert is_proposer(Uint64(1), Uint64(1), num_validators) + assert is_proposer(Uint64(2), Uint64(2), num_validators) + assert is_proposer(Uint64(0), Uint64(3), num_validators) # Wraps around diff --git a/tests/lean_spec/test_cli.py b/tests/lean_spec/test_cli.py new file mode 100644 index 00000000..daf366df --- /dev/null +++ b/tests/lean_spec/test_cli.py @@ -0,0 +1,544 @@ +"""Tests for CLI functions. + +Tests the ENR detection, bootnode resolution, and checkpoint sync functionality +used by the CLI. +""" + +from __future__ import annotations + +import asyncio +import base64 +from unittest.mock import AsyncMock, patch + +import pytest + +from lean_spec.__main__ import create_anchor_block, is_enr_string, resolve_bootnode +from lean_spec.subspecs.containers import Block, BlockBody +from lean_spec.subspecs.containers.block.types import AggregatedAttestations +from lean_spec.subspecs.containers.slot import Slot +from lean_spec.subspecs.ssz.hash import hash_tree_root +from lean_spec.types import Bytes32, Uint64 +from lean_spec.types.rlp import encode as rlp_encode +from tests.lean_spec.helpers import make_genesis_state + + +# Builders for minimal test ENR strings (derived from EIP-778 test vector structure). +# The signature is a zero-filled placeholder; IP and TCP port are supplied by the caller. +def _make_enr_with_tcp(ip_bytes: bytes, tcp_port: int) -> str: + """Create a minimal ENR string with IPv4 and TCP port.""" + rlp_data = rlp_encode( + [ + b"\x00" * 64, # signature + b"\x01", # seq = 1 + b"id", + b"v4", + b"ip", + ip_bytes, + b"secp256k1", + b"\x02" + b"\x00" * 32, # compressed pubkey + b"tcp", + tcp_port.to_bytes(2, "big"), + ] + ) + b64_content = base64.urlsafe_b64encode(rlp_data).decode("utf-8").rstrip("=") + return f"enr:{b64_content}" + + +def _make_enr_with_ipv6_tcp(ip6_bytes: bytes, tcp_port: int) -> str: + """Create a minimal ENR string with IPv6 and TCP port.""" + rlp_data = rlp_encode( + [ + b"\x00" * 64, # signature + b"\x01", # seq = 1 + b"id", 
+ b"v4", + b"ip6", + ip6_bytes, + b"secp256k1", + b"\x02" + b"\x00" * 32, # compressed pubkey + b"tcp", + tcp_port.to_bytes(2, "big"), + ] + ) + b64_content = base64.urlsafe_b64encode(rlp_data).decode("utf-8").rstrip("=") + return f"enr:{b64_content}" + + +def _make_enr_without_tcp(ip_bytes: bytes) -> str: + """Create an ENR string with IPv4 but no TCP port (UDP only).""" + rlp_data = rlp_encode( + [ + b"\x00" * 64, # signature + b"\x01", # seq = 1 + b"id", + b"v4", + b"ip", + ip_bytes, + b"secp256k1", + b"\x02" + b"\x00" * 32, # compressed pubkey + b"udp", + (30303).to_bytes(2, "big"), # UDP only, no TCP + ] + ) + b64_content = base64.urlsafe_b64encode(rlp_data).decode("utf-8").rstrip("=") + return f"enr:{b64_content}" + + +# Pre-built test ENRs +ENR_WITH_TCP = _make_enr_with_tcp(b"\xc0\xa8\x01\x01", 9000) # 192.168.1.1:9000 +ENR_WITH_IPV6_TCP = _make_enr_with_ipv6_tcp(b"\x00" * 15 + b"\x01", 9000) # ::1:9000 +ENR_WITHOUT_TCP = _make_enr_without_tcp(b"\xc0\xa8\x01\x01") # 192.168.1.1, UDP only + +# Valid multiaddr strings +MULTIADDR_IPV4 = "/ip4/127.0.0.1/tcp/9000" +MULTIADDR_IPV6 = "/ip6/::1/tcp/9000" + + +class TestIsEnrString: + """Tests for is_enr_string() detection function.""" + + def test_enr_string_detected(self) -> None: + """Valid ENR prefix returns True.""" + assert is_enr_string("enr:-IS4QHCYrYZbAKW...") is True + + def test_enr_prefix_minimal(self) -> None: + """Minimal ENR prefix 'enr:' returns True.""" + assert is_enr_string("enr:") is True + + def test_enr_with_valid_content(self) -> None: + """Full valid ENR string returns True.""" + assert is_enr_string(ENR_WITH_TCP) is True + + def test_multiaddr_not_detected(self) -> None: + """Multiaddr string returns False.""" + assert is_enr_string(MULTIADDR_IPV4) is False + assert is_enr_string(MULTIADDR_IPV6) is False + + def test_empty_string(self) -> None: + """Empty string returns False.""" + assert is_enr_string("") is False + + def test_enode_not_detected(self) -> None: + """enode:// format returns 
False.""" + enode = "enode://abc123@127.0.0.1:30303" + assert is_enr_string(enode) is False + + def test_similar_prefix_not_detected(self) -> None: + """Strings with similar but incorrect prefixes return False.""" + assert is_enr_string("ENR:") is False # Case sensitive + assert is_enr_string("enr") is False # Missing colon + assert is_enr_string("enr-") is False # Wrong separator + assert is_enr_string("enrs:") is False # Extra character + + def test_whitespace_prefix_not_detected(self) -> None: + """Whitespace before prefix returns False.""" + assert is_enr_string(" enr:abc") is False + assert is_enr_string("\tenr:abc") is False + + +class TestResolveBootnode: + """Tests for resolve_bootnode() resolution function.""" + + def test_resolve_multiaddr_unchanged(self) -> None: + """Multiaddr strings are returned unchanged.""" + assert resolve_bootnode(MULTIADDR_IPV4) == MULTIADDR_IPV4 + assert resolve_bootnode(MULTIADDR_IPV6) == MULTIADDR_IPV6 + + def test_resolve_arbitrary_multiaddr_unchanged(self) -> None: + """Any non-ENR string passes through unchanged.""" + # The function does not validate multiaddr format + arbitrary = "/some/arbitrary/path" + assert resolve_bootnode(arbitrary) == arbitrary + + def test_resolve_valid_enr_with_tcp(self) -> None: + """ENR with IPv4+TCP extracts multiaddr correctly.""" + result = resolve_bootnode(ENR_WITH_TCP) + assert result == "/ip4/192.168.1.1/tcp/9000" + + def test_resolve_enr_ipv6(self) -> None: + """ENR with IPv6+TCP extracts multiaddr correctly.""" + result = resolve_bootnode(ENR_WITH_IPV6_TCP) + # IPv6 loopback ::1 formatted as full hex + assert "/ip6/" in result + assert "/tcp/9000" in result + + def test_resolve_enr_without_tcp_raises(self) -> None: + """ENR without TCP port raises ValueError.""" + with pytest.raises(ValueError, match=r"no TCP connection info"): + resolve_bootnode(ENR_WITHOUT_TCP) + + def test_resolve_invalid_enr_raises(self) -> None: + """Malformed ENR raises ValueError.""" + # Valid base64 but invalid 
RLP structure + with pytest.raises(ValueError, match=r"Invalid RLP"): + resolve_bootnode("enr:YWJj") # "abc" in base64, not valid RLP structure + + # Another invalid RLP - too short for ENR + with pytest.raises(ValueError, match=r"(Invalid RLP|at least signature)"): + resolve_bootnode("enr:wA") # Single byte 0xc0 = empty list + + def test_resolve_enr_prefix_only_raises(self) -> None: + """ENR with prefix only (no content) raises ValueError.""" + with pytest.raises(ValueError): + resolve_bootnode("enr:") + + def test_resolve_enr_with_different_ports(self) -> None: + """ENR resolution handles various port numbers.""" + # Port 30303 + enr_30303 = _make_enr_with_tcp(b"\x7f\x00\x00\x01", 30303) + result = resolve_bootnode(enr_30303) + assert result == "/ip4/127.0.0.1/tcp/30303" + + # Port 1 (minimum valid) + enr_1 = _make_enr_with_tcp(b"\x7f\x00\x00\x01", 1) + result = resolve_bootnode(enr_1) + assert result == "/ip4/127.0.0.1/tcp/1" + + # Port 65535 (maximum) + enr_max = _make_enr_with_tcp(b"\x7f\x00\x00\x01", 65535) + result = resolve_bootnode(enr_max) + assert result == "/ip4/127.0.0.1/tcp/65535" + + def test_resolve_enr_with_different_ips(self) -> None: + """ENR resolution handles various IPv4 addresses.""" + test_cases = [ + (b"\x00\x00\x00\x00", "0.0.0.0"), + (b"\xff\xff\xff\xff", "255.255.255.255"), + (b"\x0a\x00\x00\x01", "10.0.0.1"), + ] + for ip_bytes, expected_ip in test_cases: + enr = _make_enr_with_tcp(ip_bytes, 9000) + result = resolve_bootnode(enr) + assert result == f"/ip4/{expected_ip}/tcp/9000" + + +class TestMixedBootnodes: + """Integration tests for mixed bootnode types.""" + + def test_mixed_bootnodes_list(self) -> None: + """Process a list containing both ENR and multiaddr.""" + bootnodes = [ + MULTIADDR_IPV4, + ENR_WITH_TCP, + "/ip4/10.0.0.1/tcp/8000", + ] + + resolved = [resolve_bootnode(b) for b in bootnodes] + + assert resolved[0] == MULTIADDR_IPV4 + assert resolved[1] == "/ip4/192.168.1.1/tcp/9000" + assert resolved[2] == 
"/ip4/10.0.0.1/tcp/8000" + + def test_filter_invalid_enrs(self) -> None: + """Demonstrate filtering out invalid ENRs from a bootnode list.""" + bootnodes = [ + MULTIADDR_IPV4, + ENR_WITHOUT_TCP, # Invalid - no TCP + ENR_WITH_TCP, + ] + + resolved = [] + for bootnode in bootnodes: + try: + resolved.append(resolve_bootnode(bootnode)) + except ValueError: + continue # Skip invalid + + assert len(resolved) == 2 + assert resolved[0] == MULTIADDR_IPV4 + assert resolved[1] == "/ip4/192.168.1.1/tcp/9000" + + +# ============================================================================= +# Checkpoint Sync Tests +# ============================================================================= + + +class TestCreateAnchorBlock: + """Tests for create_anchor_block() function.""" + + def test_computes_state_root_when_zero(self) -> None: + """State root is computed when header has zero state root.""" + # Arrange: Create a genesis state (header has zero state root) + state = make_genesis_state(num_validators=3, genesis_time=1000) + + # Verify the header has zero state root + assert state.latest_block_header.state_root == Bytes32.zero() + + # Act + anchor_block = create_anchor_block(state) + + # Assert: State root should be computed from the state + expected_state_root = hash_tree_root(state) + assert anchor_block.state_root == expected_state_root + assert anchor_block.state_root != Bytes32.zero() + + def test_preserves_non_zero_state_root(self) -> None: + """Non-zero state root in header is preserved.""" + # Arrange: Create a state and process a slot to fill in state root + state = make_genesis_state(num_validators=3, genesis_time=1000) + # Process slot advances and fills in the state root + state_with_root = state.process_slots(Slot(1)) + + # The state root should now be non-zero in the header + assert state_with_root.latest_block_header.state_root != Bytes32.zero() + + # Act + anchor_block = create_anchor_block(state_with_root) + + # Assert: State root is preserved from the 
header + assert anchor_block.state_root == state_with_root.latest_block_header.state_root + + def test_preserves_header_fields(self) -> None: + """Slot, proposer_index, and parent_root are preserved from header.""" + # Arrange + state = make_genesis_state(num_validators=3, genesis_time=1000) + header = state.latest_block_header + + # Act + anchor_block = create_anchor_block(state) + + # Assert: Core header fields are preserved + assert anchor_block.slot == header.slot + assert anchor_block.proposer_index == header.proposer_index + assert anchor_block.parent_root == header.parent_root + + def test_creates_empty_body(self) -> None: + """Block body contains empty attestations list.""" + # Arrange + state = make_genesis_state(num_validators=3, genesis_time=1000) + + # Act + anchor_block = create_anchor_block(state) + + # Assert: Body has empty attestations + assert len(anchor_block.body.attestations) == 0 + + def test_anchor_block_structure_is_valid(self) -> None: + """Anchor block has all required fields populated.""" + # Arrange + state = make_genesis_state(num_validators=5, genesis_time=2000) + + # Act + anchor_block = create_anchor_block(state) + + # Assert: Block has proper structure + assert isinstance(anchor_block, Block) + assert isinstance(anchor_block.slot, Slot) + assert isinstance(anchor_block.proposer_index, Uint64) + assert isinstance(anchor_block.parent_root, Bytes32) + assert isinstance(anchor_block.state_root, Bytes32) + assert isinstance(anchor_block.body, BlockBody) + assert isinstance(anchor_block.body.attestations, AggregatedAttestations) + + +class TestInitFromCheckpoint: + """Tests for _init_from_checkpoint() async function.""" + + def test_checkpoint_sync_genesis_time_mismatch_returns_none(self) -> None: + """Returns None when checkpoint state genesis time differs from local config.""" + + async def run_test() -> None: + from lean_spec.__main__ import _init_from_checkpoint + from lean_spec.subspecs.genesis import GenesisConfig + + # Arrange: 
Create checkpoint state with genesis_time=1000 + checkpoint_state = make_genesis_state(num_validators=3, genesis_time=1000) + + # Local genesis config with different genesis_time=2000 + local_genesis = GenesisConfig.model_validate( + { + "GENESIS_TIME": 2000, + "GENESIS_VALIDATORS": [], + } + ) + + # Mock the checkpoint sync client functions + mock_event_source = AsyncMock() + + with ( + patch( + "lean_spec.subspecs.api.client.fetch_finalized_state", + new_callable=AsyncMock, + return_value=checkpoint_state, + ), + patch( + "lean_spec.subspecs.api.client.verify_checkpoint_state", + new_callable=AsyncMock, + return_value=True, + ), + ): + # Act + result = await _init_from_checkpoint( + checkpoint_sync_url="http://localhost:5052", + genesis=local_genesis, + event_source=mock_event_source, + ) + + # Assert: Returns None due to genesis time mismatch + assert result is None + + asyncio.run(run_test()) + + def test_checkpoint_sync_verification_failure_returns_none(self) -> None: + """Returns None when checkpoint state verification fails.""" + + async def run_test() -> None: + from lean_spec.__main__ import _init_from_checkpoint + from lean_spec.subspecs.genesis import GenesisConfig + + # Arrange + checkpoint_state = make_genesis_state(num_validators=3, genesis_time=1000) + local_genesis = GenesisConfig.model_validate( + { + "GENESIS_TIME": 1000, + "GENESIS_VALIDATORS": [], + } + ) + + mock_event_source = AsyncMock() + + with ( + patch( + "lean_spec.subspecs.api.client.fetch_finalized_state", + new_callable=AsyncMock, + return_value=checkpoint_state, + ), + patch( + "lean_spec.subspecs.api.client.verify_checkpoint_state", + new_callable=AsyncMock, + return_value=False, # Verification fails + ), + ): + # Act + result = await _init_from_checkpoint( + checkpoint_sync_url="http://localhost:5052", + genesis=local_genesis, + event_source=mock_event_source, + ) + + # Assert + assert result is None + + asyncio.run(run_test()) + + def 
test_checkpoint_sync_network_error_returns_none(self) -> None: + """Returns None when network error occurs during fetch.""" + + async def run_test() -> None: + from lean_spec.__main__ import _init_from_checkpoint + from lean_spec.subspecs.api.client import CheckpointSyncError + from lean_spec.subspecs.genesis import GenesisConfig + + # Arrange + local_genesis = GenesisConfig.model_validate( + { + "GENESIS_TIME": 1000, + "GENESIS_VALIDATORS": [], + } + ) + + mock_event_source = AsyncMock() + + with patch( + "lean_spec.subspecs.api.client.fetch_finalized_state", + new_callable=AsyncMock, + side_effect=CheckpointSyncError("Network error: connection refused"), + ): + # Act + result = await _init_from_checkpoint( + checkpoint_sync_url="http://localhost:5052", + genesis=local_genesis, + event_source=mock_event_source, + ) + + # Assert + assert result is None + + asyncio.run(run_test()) + + def test_checkpoint_sync_success_returns_node(self) -> None: + """Successful checkpoint sync returns initialized Node.""" + + async def run_test() -> None: + from lean_spec.__main__ import _init_from_checkpoint + from lean_spec.subspecs.genesis import GenesisConfig + from lean_spec.subspecs.node import Node + + # Arrange: Create matching genesis times + genesis_time = 1000 + checkpoint_state = make_genesis_state(num_validators=3, genesis_time=genesis_time) + + local_genesis = GenesisConfig.model_validate( + { + "GENESIS_TIME": genesis_time, + "GENESIS_VALIDATORS": [], + } + ) + + # Create a mock event source with required attributes + mock_event_source = AsyncMock() + mock_event_source.reqresp_client = AsyncMock() + + with ( + patch( + "lean_spec.subspecs.api.client.fetch_finalized_state", + new_callable=AsyncMock, + return_value=checkpoint_state, + ), + patch( + "lean_spec.subspecs.api.client.verify_checkpoint_state", + new_callable=AsyncMock, + return_value=True, + ), + ): + # Act + result = await _init_from_checkpoint( + checkpoint_sync_url="http://localhost:5052", + 
genesis=local_genesis, + event_source=mock_event_source, + ) + + # Assert: Returns a Node instance + assert result is not None + assert isinstance(result, Node) + + # Verify the node's store was initialized + assert result.store is not None + + asyncio.run(run_test()) + + def test_checkpoint_sync_http_status_error_returns_none(self) -> None: + """Returns None when HTTP status error occurs.""" + + async def run_test() -> None: + from lean_spec.__main__ import _init_from_checkpoint + from lean_spec.subspecs.api.client import CheckpointSyncError + from lean_spec.subspecs.genesis import GenesisConfig + + # Arrange + local_genesis = GenesisConfig.model_validate( + { + "GENESIS_TIME": 1000, + "GENESIS_VALIDATORS": [], + } + ) + + mock_event_source = AsyncMock() + + with patch( + "lean_spec.subspecs.api.client.fetch_finalized_state", + new_callable=AsyncMock, + side_effect=CheckpointSyncError("HTTP error 404: Not Found"), + ): + # Act + result = await _init_from_checkpoint( + checkpoint_sync_url="http://localhost:5052", + genesis=local_genesis, + event_source=mock_event_source, + ) + + # Assert + assert result is None + + asyncio.run(run_test())