# Source code for netrl.channels.comm_channel

"""
comm_channel.py
===============
Defines the CommChannel abstract base class and two concrete implementations:

  GEChannel      — Gilbert-Elliott channel backed by the C++ pybind11 extension.
  PerfectChannel — Lossless, zero-delay channel for baselines and unit tests.

The ABC is the extensibility seam for future channel backends (e.g. ns3).
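To plug in a new backend, subclass CommChannel and implement the four abstract
methods (transmit, flush, reset, get_channel_info). Give the class a
constructor that accepts a NetworkConfig as its single argument, because the
factory is invoked as `channel_factory(config)`. Then pass
`channel_factory=YourChannel` to NetworkedEnv or CentralNode.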
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List, Optional, Tuple

import numpy as np

from netrl.channels.network_config import NetworkConfig


class CommChannel(ABC):
    """
    Abstract base class for simulated communication channels.

    Every concrete channel is driven through the following call protocol:

    - ``transmit(obs, step)`` -- invoked exactly once per ``env.step()`` with
      the raw observation emitted by the wrapped environment and the current
      integer step counter.
    - ``flush(step)`` -- invoked exactly once per ``env.step()``, after
      ``transmit``; hands back every queued packet whose scheduled
      arrival_step <= step.
    - ``reset()`` -- invoked on ``env.reset()``; must drop all queued packets
      and any internal channel state.
    - ``get_channel_info()`` -- diagnostic dict for logging, containing at
      least ``{"state": str, "pending_count": int}``.

    A fixed-delay channel such as GEChannel returns at most one packet from
    ``flush()`` per step. Channels with variable delay may return several.
    """

    @abstractmethod
    def transmit(
        self,
        obs: np.ndarray,
        step: int,
        packet_size: Optional[int] = None,
    ) -> None:
        """
        Send ``obs`` through the channel at integer step ``step``.

        The channel decides whether the packet is dropped. If it survives,
        a delivery step (>= ``step``) is computed and the packet is held
        internally until that step.

        Parameters
        ----------
        obs : np.ndarray
            Raw observation produced by the wrapped environment.
        step : int
            Current integer step counter (0-indexed).
        packet_size : int | None
            Payload size in bytes for this packet. ``None`` selects the
            channel's default. Channels that do not model packet-size effects
            (GE, Perfect) silently ignore this argument.
        """
        ...

    @abstractmethod
    def flush(self, step: int) -> List[Tuple[int, np.ndarray]]:
        """
        Collect every queued packet whose arrival_step <= ``step``.

        Parameters
        ----------
        step : int
            Current integer step counter.

        Returns
        -------
        list of (int, np.ndarray)
            ``(arrival_step, obs)`` pairs. The list is empty when no packet
            is due.
        """
        ...

    @abstractmethod
    def reset(self) -> None:
        """
        Drop pending packets and restore the initial internal state.

        Called on ``env.reset()``. Implementations must NOT re-seed the RNG.
        """
        ...

    @abstractmethod
    def get_channel_info(self) -> dict:
        """
        Report diagnostic information for logging or debugging.

        Returns
        -------
        dict
            Must contain at least ``{"state": str, "pending_count": int}``.
        """
        ...
# ---------------------------------------------------------------------------
class GEChannel(CommChannel):
    """
    Gilbert-Elliott channel whose simulation runs in the ``netcomm`` C++
    (pybind11) extension.

    The Good/Bad Markov-chain state, the random number generator and the
    queue of in-flight packets are all owned by the C++ ``GEChannelImpl``
    object, which keeps them atomic and reproducible.

    Parameters
    ----------
    config : NetworkConfig
        Channel and buffer configuration. Its p_gb, p_bg, loss_good,
        loss_bad, delay_steps and seed fields are forwarded to the C++
        backend.

    Raises
    ------
    ImportError
        If the ``netcomm`` extension has not been built. Build it with
        ``pip install -e .`` or ``python setup.py build_ext --inplace``.
    """

    def __init__(self, config: NetworkConfig) -> None:
        # Import lazily so the rest of the module (e.g. PerfectChannel)
        # stays usable when the C++ extension is absent.
        try:
            import netcomm  # C++ pybind11 extension
        except ImportError as exc:
            hint = (
                "netcomm C++ extension not found. "
                "Run `pip install -e .` or "
                "`python setup.py build_ext --inplace`."
            )
            raise ImportError(hint) from exc

        backend_kwargs = dict(
            p_gb=config.p_gb,
            p_bg=config.p_bg,
            loss_good=config.loss_good,
            loss_bad=config.loss_bad,
            delay_steps=config.delay_steps,
            seed=config.seed,
        )
        self._impl = netcomm.GEChannelImpl(**backend_kwargs)

    def transmit(
        self,
        obs: np.ndarray,
        step: int,
        packet_size: Optional[int] = None,
    ) -> None:
        # The GE model has no notion of payload size, so packet_size is unused.
        # The backend is handed a C-contiguous float64 array.
        payload = np.ascontiguousarray(obs, dtype=np.float64)
        self._impl.transmit(payload, step)

    def flush(self, step: int) -> List[Tuple[int, np.ndarray]]:
        # Delegate entirely to the C++ queue; its result is returned unchanged.
        due_packets = self._impl.flush(step)
        return due_packets

    def reset(self) -> None:
        # State and queue live in C++, so the backend performs the reset.
        self._impl.reset()

    def get_channel_info(self) -> dict:
        # Copy into a plain Python dict so callers get an ordinary mapping.
        raw_info = self._impl.get_channel_info()
        return dict(raw_info)
# ---------------------------------------------------------------------------
class PerfectChannel(CommChannel):
    """
    Ideal channel: nothing is lost and nothing is delayed.

    Intended for debugging and baseline runs, and does not need the
    ``netcomm`` C++ extension. A packet passed to transmit() at step ``t``
    is returned by the first flush() call whose ``step >= t``, so normally
    during the same step.
    """

    def __init__(self, config: NetworkConfig | None = None) -> None:
        # ``config`` is unused; it is accepted only so the class can be built
        # through the channel_factory(config) call signature.
        self._pending: List[Tuple[int, np.ndarray]] = []

    def transmit(
        self,
        obs: np.ndarray,
        step: int,
        packet_size: Optional[int] = None,
    ) -> None:
        # Payload size is irrelevant for a perfect channel; ignored.
        # Queue a copy so that later in-place edits of ``obs`` by the caller
        # do not alter the stored packet. Arrival step == send step.
        snapshot = obs.copy()
        self._pending.append((step, snapshot))

    def flush(self, step: int) -> List[Tuple[int, np.ndarray]]:
        # Split the queue in one pass into packets that are due and packets
        # that must keep waiting, preserving their original order.
        delivered: List[Tuple[int, np.ndarray]] = []
        waiting: List[Tuple[int, np.ndarray]] = []
        for arrival, payload in self._pending:
            if arrival <= step:
                delivered.append((arrival, payload))
            else:
                waiting.append((arrival, payload))
        self._pending = waiting
        return delivered

    def reset(self) -> None:
        # Empty the queue in place; there is no other internal state.
        self._pending.clear()

    def get_channel_info(self) -> dict:
        info = {"state": "PERFECT"}
        info["pending_count"] = len(self._pending)
        return info