124 lines
3.6 KiB
Python
124 lines
3.6 KiB
Python
import traceback
from dataclasses import dataclass, field
from typing import Any, Callable, TypeVar

from socketio import AsyncServer
|
|
|
|
|
|
# Generic result type of the factory callable passed to Socket.room_data.
T = TypeVar("T")
|
|
|
|
|
|
@dataclass()
class SocketMemory:
    """Mutable state shared between a server Socket and the sockets derived from it.

    Attributes:
        callbacks: event name -> {sid: callback} table consulted by the
            per-event dispatcher that Socket.on installs.
        room_data: room name -> (owner sid, stored value) pairs written by
            Socket.room_data and purged by Socket.clear.
    """

    # event -> sid -> async callback awaited when that sid triggers the event
    callbacks: dict[str, dict[str, Callable]] = field(default_factory=dict)
    # room -> (sid of the client that created the value, the value itself)
    room_data: dict[str, tuple[str, Any]] = field(default_factory=dict)
|
|
|
|
|
|
class Socket:
    """Event-routing wrapper around a python-socketio ``AsyncServer``.

    A Socket has one of three roles, decided by ``sid`` and ``from_sid``:

    * server socket: ``sid is None``; ``emit`` is sent with ``to=None``.
    * client socket: ``sid`` is a client session id and ``from_sid is None``;
      ``emit`` targets that sid and ``on`` stores callbacks under it.
    * room socket: produced by ``to``/``join`` from a client socket; ``sid``
      is the target (e.g. a room name) and ``from_sid`` is the originating
      client, which ``emit`` skips and which owns any ``room_data``.

    Sockets produced by ``to``/``join`` share the same ``SocketMemory``.
    """

    sio: AsyncServer
    from_sid: str | None
    sid: str | None
    mem: SocketMemory

    def __init__(self, sio: AsyncServer, mem: SocketMemory | None = None) -> None:
        """Wrap *sio*, sharing *mem* or creating a fresh memory when it is None."""
        self.sio = sio
        self.from_sid = None
        self.sid = None

        if mem is None:
            self.mem = SocketMemory()
            # self.sid is None, so no callback is stored; this only installs
            # the "disconnect" dispatcher, which calls clear(sid).
            self.on("disconnect")(lambda: None)
        else:
            self.mem = mem

    def __str__(self) -> str:
        if self.sid is None:
            return "Server Socket"
        if self.from_sid is None:
            return f"Client Socket {{id: {self.sid}}}"
        # For a room socket, sid is the room and from_sid is the client; the
        # labels previously had these two swapped.
        return f"Room Socket {{id: {self.from_sid}, room_id: {self.sid}}}"

    def clear(self, sid: str) -> None:
        """Forget every callback registered for *sid* and any room data it owns."""
        for per_client in self.mem.callbacks.values():
            per_client.pop(sid, None)

        # Collect the keys first: deleting while iterating the dict would raise.
        owned = [
            room
            for room, (owner, _value) in self.mem.room_data.items()
            if owner == sid
        ]
        for room in owned:
            del self.mem.room_data[room]

    def to(self, sid: str) -> "Socket":
        """Return a socket that targets *sid*, originating from this socket's sid."""
        socket = Socket(self.sio, self.mem)
        socket.from_sid = self.sid
        socket.sid = sid
        return socket

    def on(self, event: str):
        """Return a decorator registering an async callback for *event*.

        Each call installs (via ``sio.on``) a dispatcher for *event*. When the
        event fires for client ``sid``, the callback stored for that sid, if
        any, is awaited with the event's arguments. For ``"disconnect"`` the
        dispatcher then calls ``clear(sid)``.

        The decorated coroutine function is stored under ``self.sid``, so
        nothing is stored when ``self.sid`` is None. Exceptions it raises are
        printed (with traceback) and not propagated. ``f`` is returned
        unchanged.
        """

        def wrapper(f: Callable):
            if event not in self.mem.callbacks:
                self.mem.callbacks[event] = {}

            @self.sio.on(event)  # type: ignore
            async def handler(sid: str, *args):
                try:
                    callback = self.mem.callbacks[event].get(sid)
                    if callback is not None:
                        await callback(*args)
                finally:
                    # Always drop a disconnected client's state, even if the
                    # dispatch above was interrupted (e.g. cancelled).
                    if event == "disconnect":
                        self.clear(sid)

            if self.sid is not None:

                async def wrapper_inner(*args):
                    try:
                        await f(*args)
                    except Exception as e:
                        # Best effort: a failing user callback must not break
                        # event dispatch, but keep the traceback for debugging.
                        print("Error:", e)
                        traceback.print_exc()

                self.mem.callbacks[event][self.sid] = wrapper_inner

            return f

        return wrapper

    async def emit(self, event: str, *args):
        """Emit *event*; *args* is passed as a tuple, i.e. as separate arguments.

        A socket without ``from_sid`` sends ``to=self.sid``; otherwise the
        event goes to ``room=self.sid`` skipping the originating client.
        """
        if self.from_sid is None:
            await self.sio.emit(event, args, to=self.sid)
        else:
            await self.sio.emit(event, args, room=self.sid, skip_sid=self.from_sid)

    async def join(self, room: str) -> "Socket":
        """Put ``self.sid`` into *room* and return a room socket for it.

        Intended for client sockets: ``self.sid`` is used as the session id.
        """
        await self.sio.enter_room(self.sid, room)
        return self.to(room)

    async def leave(self, room: str):
        """Remove ``self.sid`` from *room*."""
        await self.sio.leave_room(self.sid, room)

    def room_exists(self, room: str) -> bool:
        """Return True when *room* has at least one member in namespace "/"."""
        # NOTE(review): reads the manager's internal ``rooms`` table; confirm
        # its layout when upgrading python-socketio.
        rooms = self.sio.manager.rooms.get("/", {})
        return room in rooms and len(rooms[room]) > 0

    def room_data(self, f: Callable[[], T]) -> T:
        """Return the value stored for this room, creating it with *f* on first use.

        The value is owned by ``from_sid`` and is discarded by ``clear`` when
        that client disconnects.

        Raises:
            RuntimeError: if this is not a room socket.
        """
        if self.from_sid is None or self.sid is None:
            raise RuntimeError(f"{self} cannot use room data")

        entry = self.mem.room_data.get(self.sid)
        if entry is not None:
            return entry[1]

        data = f()
        self.mem.room_data[self.sid] = (self.from_sid, data)
        return data
|
|
|
|
|
|
# One server Socket (and hence one SocketMemory) per AsyncServer. Creating a
# new one per connection re-pointed the shared event dispatchers at the newest
# memory, orphaning the callbacks of previously connected clients.
_server_sockets: dict[AsyncServer, Socket] = {}


async def connect(sio: AsyncServer, sid: str):
    """Example connect handler for client *sid*.

    Registers a "msg" listener for the client, then greets it with a "msg"
    event carrying a string, an int and a list.
    """
    server = _server_sockets.get(sio)
    if server is None:
        server = _server_sockets[sio] = Socket(sio)
    socket = server.to(sid)

    print(f"Client #{sid} connected")

    # Register the listener before emitting so an immediate reply is handled.
    @socket.on("msg")
    async def on_msg(s: str, n: int, items: list):
        print(s, n * 10, items[0])

    await socket.emit("msg", "Test", 123, [1, True, None])
|