CentralNode

CentralNode owns one CommChannel and one ObservationBuffer per registered node. It is the aggregation layer used internally by both environment wrappers, and can also be used directly for custom multi-agent pipelines.

class netrl.CentralNode(node_ids, obs_shape, obs_dtype, config, channel_factory)

Bases: object

Aggregator that owns one CommChannel and one ObservationBuffer per node.

Parameters:
  • node_ids (List[str]) – Unique string identifiers for each distributed node (agent).

  • obs_shape (tuple | List[tuple]) – Shape of a single observation. Pass a single tuple (e.g. (4,)) to use the same shape for every node, or a List[tuple] (one entry per node) to give each node its own observation shape.

  • obs_dtype (dtype | List[dtype]) – NumPy dtype of observations. Same broadcast rules as obs_shape: a single value is applied to all nodes; a list assigns per-node.

  • config (NetworkConfig) – Channel + buffer configuration shared across all nodes. Each node gets a copy with seed = config.seed + node_index so the per-node RNGs are independent.

  • channel_factory (Callable[[NetworkConfig], CommChannel]) – Callable that takes a NetworkConfig and returns a CommChannel. Swap for PerfectChannel, NS3WifiChannel, or any custom channel without changing this class.
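The broadcast and seeding rules above can be sketched as follows. This is a minimal sketch, not netrl's implementation: broadcast_per_node, per_node_configs and ToyNetworkConfig are hypothetical names. It assumes that a Python list signals per-node values, while any other value (including a tuple such as (4,) or a single dtype) is repeated for every node.

```python
from dataclasses import dataclass, replace
from typing import Any, List


@dataclass(frozen=True)
class ToyNetworkConfig:
    """Hypothetical stand-in for NetworkConfig with only the fields used here."""
    buffer_size: int = 10
    seed: int = 0


def broadcast_per_node(value: Any, n_nodes: int, name: str) -> List[Any]:
    """Assumed rule: a list gives one entry per node; anything else is repeated."""
    if isinstance(value, list):
        if len(value) != n_nodes:
            raise ValueError(f"{name}: expected {n_nodes} entries, got {len(value)}")
        return list(value)
    return [value] * n_nodes


def per_node_configs(config: ToyNetworkConfig, n_nodes: int) -> List[ToyNetworkConfig]:
    """Copy the shared config, giving node i the seed config.seed + i."""
    return [replace(config, seed=config.seed + i) for i in range(n_nodes)]


node_ids = ["agent_0", "agent_1", "agent_2"]
shapes = broadcast_per_node((4,), len(node_ids), "obs_shape")  # same shape for all
configs = per_node_configs(ToyNetworkConfig(buffer_size=10, seed=42), len(node_ids))
```

Offsetting the seed by the node index keeps runs reproducible from a single config.seed while still giving each channel its own random stream.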

__init__(node_ids, obs_shape, obs_dtype, config, channel_factory)

Parameters: identical to the class parameters above.

Return type:

None

receive_from(node_id, obs, step, packet_size=None)

Transmit obs from node_id through its channel.

Parameters:
  • node_id (str) – Must match one of the ids given at construction.

  • obs (np.ndarray) – Raw local observation to be transmitted.

  • step (int) – Current integer step counter.

  • packet_size (int | None) – Payload bytes for this packet. None means use the channel’s own default.

Return type:

None

flush_and_update(step)

Flush all channels for step and update each observation buffer.

For each node:
  • flush its channel to retrieve all packets due at this step;

  • add all of those packets to the buffer, each stamped with the time step of the observation it carries.

When a packet arrives at step S with delay_steps=D, the observation it contains is from time step S - D, so it is added to the buffer under that step number.

How many packets arrive per step depends on the channel:
  • GEChannel (fixed delay): typically 0 or 1 packet per step.

  • NS3WifiChannel (variable delay): 0 to N packets per step, due to retransmissions and variable latencies; all of them are added.

The arrived_map records the last packet that arrived for each node (for backward compatibility and info reporting).

Returns:

The last observation that arrived for each node this step, or None if no packet arrived for that node.

Return type:

Dict[node_id -> obs | None]

Parameters:

step (int)
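A toy model of the per-node flush loop described above, assuming a channel that delivers every packet exactly delay steps after it was sent. ToyFixedDelayChannel, Packet and the free function flush_and_update are hypothetical stand-ins, not netrl classes, and the buffers here are plain lists of (observation_step, obs) pairs rather than ObservationBuffer objects.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple


@dataclass
class Packet:
    obs: Any
    delay_steps: int  # D: steps between transmission and arrival


class ToyFixedDelayChannel:
    """Hypothetical channel: every packet arrives exactly `delay` steps after sending."""

    def __init__(self, delay: int) -> None:
        self.delay = delay
        # arrival step -> packets scheduled to arrive at that step
        self._in_flight: Dict[int, List[Packet]] = defaultdict(list)

    def transmit(self, obs: Any, step: int) -> None:
        self._in_flight[step + self.delay].append(Packet(obs, self.delay))

    def flush(self, step: int) -> List[Packet]:
        # Remove and return every packet due at `step` (possibly none).
        return self._in_flight.pop(step, [])


def flush_and_update(
    channels: Dict[str, ToyFixedDelayChannel],
    buffers: Dict[str, List[Tuple[int, Any]]],
    step: int,
) -> Dict[str, Optional[Any]]:
    arrived_map: Dict[str, Optional[Any]] = {}
    for node_id, channel in channels.items():
        arrived_map[node_id] = None
        for pkt in channel.flush(step):  # 0..N packets
            obs_step = step - pkt.delay_steps  # S - D: when the observation was taken
            buffers[node_id].append((obs_step, pkt.obs))
            arrived_map[node_id] = pkt.obs  # the last arrival wins
    return arrived_map


# agent_0 has no delay, agent_1 a fixed 2-step delay.
channels = {"agent_0": ToyFixedDelayChannel(delay=0), "agent_1": ToyFixedDelayChannel(delay=2)}
buffers: Dict[str, List[Tuple[int, Any]]] = {node_id: [] for node_id in channels}
for step in range(4):
    for node_id, channel in channels.items():
        channel.transmit(f"{node_id}@{step}", step)
    arrived = flush_and_update(channels, buffers, step)
```

With a variable-delay channel a single flush could return several packets; the inner loop already handles that, and arrived_map simply keeps the last one.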

get_buffer(node_id)

Return the padded (obs_array, recv_mask) for node_id.

Shapes:

obs_array : (buffer_size, *obs_shape)
recv_mask : (buffer_size,), dtype bool

The most recent entry is at index [-1]; older entries are to the left. Unwritten slots are zeros with recv_mask = False.

Parameters:

node_id (str)

Return type:

Tuple[ndarray, ndarray]
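The padded layout can be illustrated with a small NumPy sketch. pad_history is a hypothetical helper, not part of netrl; it assumes the history is a sequence of observations ordered oldest to newest and keeps only the newest buffer_size of them.

```python
from typing import Sequence, Tuple

import numpy as np


def pad_history(
    history: Sequence[np.ndarray],
    buffer_size: int,
    obs_shape: Tuple[int, ...],
    dtype: type = np.float32,
) -> Tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper: right-align the newest observations in a zero-padded array."""
    obs_array = np.zeros((buffer_size, *obs_shape), dtype=dtype)
    recv_mask = np.zeros(buffer_size, dtype=bool)
    recent = list(history)[-buffer_size:]  # keep at most the newest buffer_size entries
    n = len(recent)
    if n:
        # Newest observation lands at index -1; older ones fill leftwards.
        obs_array[buffer_size - n:] = np.stack(recent).astype(dtype, copy=False)
        recv_mask[buffer_size - n:] = True
    # Slots left of the written region stay zero with recv_mask == False.
    return obs_array, recv_mask


hist = [np.full(4, i, dtype=np.float32) for i in range(3)]  # oldest first
obs_array, recv_mask = pad_history(hist, buffer_size=5, obs_shape=(4,))
```

Here obs_array has shape (5, 4), the first two rows are zero padding with recv_mask False, and obs_array[-1] holds the newest observation.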

get_all_buffers()

Return get_buffer(node_id) for every registered node, as a dict keyed by node id.

Return type:

Dict[str, Tuple[ndarray, ndarray]]
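Because get_all_buffers() returns one (obs_array, recv_mask) pair per node, nodes built with different obs_shape values yield arrays of different sizes. One hypothetical way a custom pipeline could turn them into a single fixed-length vector is to flatten and concatenate them in a fixed node order. flatten_all_buffers is not a netrl function, and ordering nodes by sorted id is an assumption made here so the layout does not depend on dict order.

```python
from typing import Dict, Tuple

import numpy as np


def flatten_all_buffers(buffers: Dict[str, Tuple[np.ndarray, np.ndarray]]) -> np.ndarray:
    """Hypothetical helper: concatenate every node's flattened buffer and mask."""
    parts = []
    for node_id in sorted(buffers):  # assumed fixed order: sorted node ids
        obs_array, recv_mask = buffers[node_id]
        parts.append(obs_array.reshape(-1).astype(np.float32))
        parts.append(recv_mask.astype(np.float32))  # 1.0 = received, 0.0 = empty slot
    return np.concatenate(parts)


# Mimics the per-node example: lidar obs_shape (2,), camera (8,), buffer_size 8.
buffers = {
    "lidar": (np.zeros((8, 2), dtype=np.float32), np.zeros(8, dtype=bool)),
    "camera": (np.ones((8, 8), dtype=np.float32), np.ones(8, dtype=bool)),
}
flat = flatten_all_buffers(buffers)
```

Appending the masks lets a downstream model distinguish a genuinely zero observation from an empty slot.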

get_channel_info(node_id)

Return diagnostic channel state dict for node_id.

Parameters:

node_id (str)

Return type:

dict

reset()

Reset all channels and buffers. Call this whenever the environment is reset, i.e. alongside env.reset().

Return type:

None

property node_ids: List[str]
property config: NetworkConfig

Direct usage example

import numpy as np
import gymnasium as gym
from netrl import CentralNode, NetworkConfig
from netrl.channels.comm_channel import GEChannel

central = CentralNode(
    node_ids=["agent_0", "agent_1"],
    obs_shape=(4,),          # single tuple → same shape for all nodes
    obs_dtype=np.float32,
    config=NetworkConfig(buffer_size=10, seed=42),
    channel_factory=GEChannel,   # called with each node's NetworkConfig copy
)

env = gym.make("CartPole-v1")
obs, _ = env.reset()
central.reset()

for step in range(200):
    # Both agents send the current observation; agent_1 overrides the packet size.
    central.receive_from("agent_0", obs, step)
    central.receive_from("agent_1", obs, step, packet_size=256)
    # Deliver packets due at this step; arrived maps node_id -> last obs or None.
    arrived = central.flush_and_update(step)
    obs, _, term, trunc, _ = env.step(env.action_space.sample())
    if term or trunc:
        obs, _ = env.reset()

buf_0, mask_0 = central.get_buffer("agent_0")
# buf_0.shape == (10, 4),  mask_0.shape == (10,)

Per-node observation shapes

Pass a list to give each node its own observation shape:

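import numpy as np
from netrl import CentralNode, NetworkConfig
from netrl.channels.comm_channel import GEChannel

central = CentralNode(
    node_ids=["lidar", "camera"],
    obs_shape=[(2,), (8,)],    # list → per-node shapes
    obs_dtype=[np.float32, np.float32],   # list → per-node dtypes
    config=NetworkConfig(buffer_size=8),
    channel_factory=GEChannel,
)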