CentralNode
CentralNode owns one CommChannel and one
ObservationBuffer per registered node. It is the aggregation
layer used internally by both environment wrappers, and can also be used
directly for custom multi-agent pipelines.
- class netrl.CentralNode(node_ids, obs_shape, obs_dtype, config, channel_factory)
Bases: object
Aggregator that owns one CommChannel and one ObservationBuffer per node.
- Parameters:
node_ids (List[str]) – Unique string identifiers for each distributed node (agent).
obs_shape (tuple | List[tuple]) – Shape of a single observation. Pass a single tuple (e.g. (4,)) to use the same shape for every node, or a List[tuple] (one entry per node) to give each node its own observation shape.
obs_dtype (dtype | List[dtype]) – NumPy dtype of observations. Same broadcast rules as obs_shape: a single value is applied to all nodes; a list assigns one dtype per node.
config (NetworkConfig) – Channel and buffer configuration shared across all nodes. Each node gets a copy with seed = config.seed + node_index, so the per-node RNGs are independent.
channel_factory (Callable[[NetworkConfig], CommChannel]) – Callable that takes a NetworkConfig and returns a CommChannel. Swap in PerfectChannel, NS3WifiChannel, or any custom channel without changing this class.
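The broadcast and seeding rules for obs_shape, obs_dtype and config can be sketched in plain Python. This is a minimal illustration, not netrl's code: the helpers broadcast and per_node_rngs are hypothetical, and the sketch assumes that only a Python list counts as per-node input, so a tuple such as (4,) is always read as a single shared shape.

```python
from typing import Any, List

import numpy as np


def broadcast(value: Any, n_nodes: int) -> List[Any]:
    """Hypothetical helper: a list is taken as one entry per node;
    anything else (a tuple shape, a dtype) is repeated for every node."""
    if isinstance(value, list):
        if len(value) != n_nodes:
            raise ValueError(f"expected {n_nodes} entries, got {len(value)}")
        return list(value)
    return [value] * n_nodes


def per_node_rngs(base_seed: int, n_nodes: int) -> List[np.random.Generator]:
    """Hypothetical helper: one independent RNG per node,
    seeded with base_seed + node_index."""
    return [np.random.default_rng(base_seed + i) for i in range(n_nodes)]


node_ids = ["lidar", "camera"]
shapes = broadcast([(2,), (8,)], len(node_ids))  # list -> per-node shapes
dtypes = broadcast(np.float32, len(node_ids))    # single value -> shared dtype
rngs = per_node_rngs(42, len(node_ids))          # seeds 42 and 43

for nid, shape, dtype in zip(node_ids, shapes, dtypes):
    print(nid, shape, np.dtype(dtype))
# lidar (2,) float32
# camera (8,) float32
```

Treating only lists as per-node input keeps the two cases unambiguous: a shape is itself a tuple, so a tuple of shapes would be indistinguishable from a multi-dimensional shape.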
- __init__(node_ids, obs_shape, obs_dtype, config, channel_factory)
- Parameters:
config (NetworkConfig)
channel_factory (Callable[[NetworkConfig], CommChannel])
- Return type:
None
- receive_from(node_id, obs, step, packet_size=None)
Transmit obs from node_id through its channel.
- Parameters:
node_id (str) – Must match one of the ids given at construction.
obs (np.ndarray) – Raw local observation to be transmitted.
step (int) – Current integer step counter.
packet_size (int | None) – Payload bytes for this packet. None means use the channel's own default.
- Return type:
None
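receive_from only hands the observation to the node's channel; when it comes back out depends on the channel. The sketch below is a minimal fixed-delay, lossless channel to make that hand-off concrete. FixedDelayChannel, its send/flush methods and the 64-byte default payload are hypothetical and are not netrl's CommChannel interface or GEChannel's loss model.

```python
import heapq
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

import numpy as np


@dataclass
class FixedDelayChannel:
    """Hypothetical stand-in for a CommChannel (not netrl's class): every
    packet is delivered exactly `delay_steps` steps after it is sent, with no loss."""

    delay_steps: int = 1
    default_packet_size: int = 64  # assumed default payload size in bytes
    # Heap entries: (arrival_step, sequence_no, sent_step, obs, packet_size)
    _queue: list = field(default_factory=list)
    _seq: int = 0

    def send(self, obs: np.ndarray, step: int, packet_size: Optional[int] = None) -> None:
        # packet_size=None falls back to the channel's own default.
        size = self.default_packet_size if packet_size is None else packet_size
        entry = (step + self.delay_steps, self._seq, step, np.array(obs, copy=True), size)
        heapq.heappush(self._queue, entry)
        self._seq += 1  # unique tie-breaker, so arrays are never compared

    def flush(self, step: int) -> List[Tuple[np.ndarray, int, int]]:
        """Return (obs, delay_steps, packet_size) for every packet due by `step`."""
        due = []
        while self._queue and self._queue[0][0] <= step:
            _, _, sent_step, obs, size = heapq.heappop(self._queue)
            due.append((obs, step - sent_step, size))
        return due


ch = FixedDelayChannel(delay_steps=2)
ch.send(np.array([1.0, 2.0]), step=0)
ch.send(np.array([3.0, 4.0]), step=1, packet_size=256)
print(ch.flush(1))       # []  (nothing is due yet)
print(len(ch.flush(2)))  # 1   (the packet sent at step 0)
```

Copying the array on send means a caller that reuses or mutates its observation buffer cannot change a packet that is already in flight.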
- flush_and_update(step)
Flush all channels for step and update each observation buffer.
- For each node:
  1. Flush its channel to retrieve all packets due at this step.
  2. Add ALL packets to the buffer with their correct observation times.
When a packet arrives at step S with delay_steps = D, the observation it contains is from time step S - D, so it is added with that step number.
With GEChannel (fixed delay), typically 0 or 1 packet arrives per step. With NS3WifiChannel (variable delay), 0 to N packets can arrive per step because of retransmissions and variable latencies; all of them are added.
The arrived_map records the last packet that arrived for each node (for backward compatibility and info reporting).
- Returns:
The last observation that arrived for each node this step, or None.
- Return type:
Dict[node_id -> obs | None]
- Parameters:
step (int)
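The per-node update rule can be written out as a short sketch. It assumes each arrived packet is available as an (obs, delay_steps) pair and models each buffer as a dict keyed by observation step; flush_and_update_sketch and that packet layout are hypothetical, chosen only to show the S - D bookkeeping and the last-arrival map.

```python
from typing import Dict, List, Optional, Tuple

import numpy as np

Packet = Tuple[np.ndarray, int]  # hypothetical layout: (observation, delay_steps D)


def flush_and_update_sketch(
    arrivals: Dict[str, List[Packet]],
    buffers: Dict[str, Dict[int, np.ndarray]],
    step: int,
) -> Dict[str, Optional[np.ndarray]]:
    """For each node, store every packet that arrived at `step` under its
    observation time (step - D) and report the last arrival, or None."""
    arrived_map: Dict[str, Optional[np.ndarray]] = {}
    for node_id, buf in buffers.items():
        packets = arrivals.get(node_id, [])
        for obs, delay_steps in packets:
            obs_time = step - delay_steps  # observation was taken at S - D
            buf[obs_time] = obs
        arrived_map[node_id] = packets[-1][0] if packets else None
    return arrived_map


# Variable-delay case: two packets reach agent_0 at step 5, none reach agent_1.
buffers = {"agent_0": {}, "agent_1": {}}
arrivals = {"agent_0": [(np.array([0.1]), 3), (np.array([0.4]), 1)]}
last = flush_and_update_sketch(arrivals, buffers, step=5)
print(sorted(buffers["agent_0"]))  # [2, 4]
print(last["agent_1"])             # None
```

Keying by observation time rather than arrival time is what lets out-of-order packets from a variable-delay channel land in the right slot.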
- get_buffer(node_id)
Return the padded (obs_array, recv_mask) for node_id.
- Shapes:
obs_array : (buffer_size, *obs_shape)
recv_mask : (buffer_size,), dtype bool
The most recent entry is at index [-1]; older entries are to the left. Unwritten slots are zeros with recv_mask = False.
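The padded layout can be illustrated with a small sketch. It assumes slots are indexed by observation time relative to the current step (slot -1 is the current step, slot -1 - k is k steps earlier); the library may instead keep the last buffer_size received entries, and padded_window is a hypothetical helper, not part of netrl.

```python
from typing import Dict, Tuple

import numpy as np


def padded_window(
    received: Dict[int, np.ndarray],
    current_step: int,
    buffer_size: int,
    obs_shape: Tuple[int, ...],
    dtype=np.float32,
) -> Tuple[np.ndarray, np.ndarray]:
    """Hypothetical layout: slot -1 holds the observation from current_step,
    slot -1 - k the one from current_step - k. Missing slots stay zero/False."""
    obs_array = np.zeros((buffer_size, *obs_shape), dtype=dtype)
    recv_mask = np.zeros((buffer_size,), dtype=bool)
    for k in range(buffer_size):
        t = current_step - k
        if t in received:
            obs_array[buffer_size - 1 - k] = received[t]
            recv_mask[buffer_size - 1 - k] = True
    return obs_array, recv_mask


received = {2: np.array([0.1, 0.2]), 4: np.array([0.3, 0.4])}
obs_array, recv_mask = padded_window(received, current_step=4, buffer_size=4, obs_shape=(2,))
print(obs_array.shape, recv_mask.tolist())  # (4, 2) [False, True, False, True]
```

Because unwritten slots are zeros, a policy consuming obs_array should read recv_mask to tell a genuine all-zero observation from a missing one.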
- property config: NetworkConfig
Direct usage example
import numpy as np
import gymnasium as gym
from netrl import CentralNode, NetworkConfig
from netrl.channels.comm_channel import GEChannel
central = CentralNode(
node_ids=["agent_0", "agent_1"],
obs_shape=(4,), # single tuple → same shape for all nodes
obs_dtype=np.float32,
config=NetworkConfig(buffer_size=10, seed=42),
channel_factory=GEChannel,
)
env = gym.make("CartPole-v1")
obs, _ = env.reset()
central.reset()
for step in range(200):
    # Both nodes transmit the same CartPole observation; agent_1 uses a 256-byte payload.
    central.receive_from("agent_0", obs, step)
    central.receive_from("agent_1", obs, step, packet_size=256)
    arrived = central.flush_and_update(step)
    obs, _, term, trunc, _ = env.step(env.action_space.sample())
    if term or trunc:
        obs, _ = env.reset()
buf_0, mask_0 = central.get_buffer("agent_0")
# buf_0.shape == (10, 4), mask_0.shape == (10,)
Per-node observation shapes
Pass a list to give each node its own observation shape:
central = CentralNode(
node_ids=["lidar", "camera"],
obs_shape=[(2,), (8,)], # list → per-node shapes
obs_dtype=[np.float32, np.float32],
config=NetworkConfig(buffer_size=8),
channel_factory=GEChannel,
)