"""Per-client / per-room convenience wrapper around a python-socketio ``AsyncServer``."""

import inspect
from dataclasses import dataclass, field
from typing import Any, Callable, TypeVar

from socketio import AsyncServer

T = TypeVar("T")


@dataclass
class SocketMemory:
    """State shared by every ``Socket`` derived from the same server socket."""

    # event name -> {sid -> callback}; the dispatcher installed by Socket.on
    # only invokes the callback stored under the sid that sent the event.
    callbacks: dict[str, dict[str, Callable]] = field(default_factory=dict)
    # room id -> (owner sid, data); entries owned by a sid are dropped when
    # that sid disconnects (see Socket.clear).
    room_data: dict[str, tuple[str, Any]] = field(default_factory=dict)


class Socket:
    """Addressable view onto an ``AsyncServer``.

    The flavour is determined by ``sid`` and ``from_sid``:

    * server socket: ``sid is None`` -- ``emit`` passes ``to=None``;
    * client socket: ``sid`` is a client sid and ``from_sid is None``;
    * room socket (from :meth:`join`): ``sid`` is the room name and
      ``from_sid`` the client that joined it -- ``emit`` targets the room
      and skips ``from_sid``.

    Every socket obtained through :meth:`to` / :meth:`join` shares the same
    :class:`SocketMemory` as the socket it was derived from.
    """

    sio: AsyncServer
    from_sid: str | None
    sid: str | None
    mem: SocketMemory

    def __init__(self, sio: AsyncServer, mem: SocketMemory | None = None) -> None:
        """Wrap *sio*; when *mem* is omitted this is a new server socket."""
        self.sio = sio
        self.from_sid = None
        self.sid = None
        if mem is None:
            self.mem = SocketMemory()
            # Install the "disconnect" dispatcher up front so clear() always
            # runs for departing clients, even if nobody registers a callback.
            self.on("disconnect")(lambda: None)
        else:
            self.mem = mem

    def __str__(self) -> str:
        if self.sid is None:
            return "Server Socket"
        if self.from_sid is None:
            return "Client Socket {id: " + self.sid + "}"
        # Room socket: from_sid is the client, sid is the room (see to()/join()).
        return "Room Socket {id: " + self.from_sid + ", room_id: " + self.sid + "}"

    def clear(self, sid: str) -> None:
        """Forget every callback registered for *sid* and all room data it owns."""
        for handlers in self.mem.callbacks.values():
            handlers.pop(sid, None)
        # Collect first: the dict cannot be mutated while being iterated.
        owned = [room for room, (owner, _) in self.mem.room_data.items() if owner == sid]
        for room in owned:
            del self.mem.room_data[room]

    def to(self, sid: str) -> "Socket":
        """Return a socket addressing *sid*, originating from this socket's sid."""
        socket = Socket(self.sio, self.mem)
        socket.from_sid = self.sid
        socket.sid = sid
        return socket

    def on(self, event: str) -> Callable[[Callable], Callable]:
        """Return a decorator that registers ``f`` as this socket's *event* callback.

        The first registration of *event* (per shared memory) installs one
        dispatcher on the server. The dispatcher forwards each incoming event
        to the callback stored for the sender's sid and, for "disconnect",
        then calls :meth:`clear` for that sid. ``f`` may be a plain function
        or a coroutine function. Exceptions it raises are printed, not
        propagated. On a server socket (``sid is None``) only the dispatcher
        is installed. The decorated function is returned unchanged.
        """

        def decorator(f: Callable) -> Callable:
            if event not in self.mem.callbacks:
                self.mem.callbacks[event] = {}

                @self.sio.on(event)  # type: ignore
                async def handler(sid: str, *args):
                    callback = self.mem.callbacks[event].get(sid)
                    if callback is not None:
                        await callback(*args)
                    if event == "disconnect":
                        self.clear(sid)

            if self.sid is not None:

                async def guarded(*args):
                    try:
                        result = f(*args)
                        # Support both sync callables and coroutine functions.
                        if inspect.isawaitable(result):
                            await result
                    except Exception as e:
                        print("Error:", e)

                self.mem.callbacks[event][self.sid] = guarded
            return f

        return decorator

    async def emit(self, event: str, *args) -> None:
        """Send *event* with ``args`` delivered as separate positional arguments.

        Server/client sockets pass ``to=self.sid``; room sockets address the
        room and skip the client that joined it.
        """
        if self.from_sid is None:
            await self.sio.emit(event, args, to=self.sid)
        else:
            await self.sio.emit(event, args, room=self.sid, skip_sid=self.from_sid)

    async def join(self, room: str) -> "Socket":
        """Put this socket's sid into *room* and return the matching room socket."""
        await self.sio.enter_room(self.sid, room)
        return self.to(room)

    async def leave(self, room: str) -> None:
        """Remove this socket's sid from *room*."""
        await self.sio.leave_room(self.sid, room)

    def room_exists(self, room: str) -> bool:
        """Return True if *room* in the default namespace has any members."""
        rooms = self.sio.manager.rooms.get("/", {})
        return bool(rooms.get(room))

    def room_data(self, f: Callable[[], T]) -> T:
        """Return this room's shared data, creating it with ``f()`` on first use.

        The creating client (``from_sid``) becomes the owner; the data is
        discarded when the owner disconnects.

        Raises:
            RuntimeError: if this is not a room socket.
        """
        if self.from_sid is None or self.sid is None:
            raise RuntimeError(f"{self} cannot use room data")
        entry = self.mem.room_data.get(self.sid)
        if entry is not None:
            return entry[1]
        data = f()
        self.mem.room_data[self.sid] = self.from_sid, data
        return data


# One server Socket (hence one SocketMemory) per AsyncServer. Building a new
# one per connection would re-register the sio dispatchers against fresh,
# empty memory and orphan every earlier client's callbacks.
# NOTE(review): holds a strong reference to each server for the process
# lifetime; presumably there is a single long-lived server — confirm.
_server_sockets: dict[AsyncServer, Socket] = {}


def _server_socket(sio: AsyncServer) -> Socket:
    """Return the shared server ``Socket`` for *sio*, creating it on first use."""
    server = _server_sockets.get(sio)
    if server is None:
        server = Socket(sio)
        _server_sockets[sio] = server
    return server


async def connect(sio: AsyncServer, sid: str):
    """Demo connection handler: greet client *sid* and print its "msg" events."""
    socket = _server_socket(sio).to(sid)
    print(f"Client #{sid} connected")

    # Register before emitting so a quick reply cannot beat the callback.
    @socket.on("msg")
    async def on_msg(s: str, n: int, items: list):
        print(s, n * 10, items[0])

    await socket.emit("msg", "Test", 123, [1, True, None])