Source code for netrl.utils.multi_view_model

from gymnasium import Env, spaces
import numpy as np
import gymnasium as gym

from typing import List

[docs] class MultiViewModel:
[docs] def __init__(self, observer_ids: List[str], obs_shapes: List[tuple], obs_dtypes: List[np.dtype] ): self.observer_ids = observer_ids self.obs_shapes = obs_shapes self.obs_dtypes = obs_dtypes self.spaces = { oid: spaces.Box( low=-np.inf, high=np.inf, shape=shape, dtype=dtype, ) for oid, shape, dtype in zip(observer_ids, obs_shapes, obs_dtypes) }
[docs] def observe(self, env: gym.Env = None, state = None): """ Class to implement the observations of each node """ raise NotImplementedError("This method should be implemented by the user to return the observations for each observer.")