-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
101 lines (80 loc) · 3.06 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import asyncio
import threading
from dataclasses import dataclass
from fastapi import FastAPI, Request, WebSocket
from fastapi.responses import RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from starlette.websockets import WebSocketDisconnect, WebSocketState
@dataclass()
class Connection:
    """Shared state for all websockets attached to one URL path."""

    # Most recent text broadcast on this path; sent to peers that join later.
    text: str
    # Websockets currently registered on this path, keyed by id(websocket).
    websockets: dict[int, WebSocket]
class WebSocketBroadcaster:
    """Fan out text messages among the websockets that share a URL path.

    Each path maps to a ``Connection`` holding the latest text and the
    websockets registered for that path, keyed by ``id(websocket)``.

    ``self.lock`` guards mutations of ``self.connections``. It is a
    ``threading.Lock``, so it must never be held across an ``await``: if a
    coroutine suspended while holding it, any other coroutine on the same
    event-loop thread trying to acquire it would block the entire loop.
    """

    def __init__(self):
        self.connections: dict[str, Connection] = {}
        self.lock = threading.Lock()

    def conn_count(self, path: str) -> int:
        """Return how many websockets are registered for *path* (0 if none)."""
        conn = self.connections.get(path)
        return len(conn.websockets) if conn is not None else 0

    async def broadcast(self, path: str, sender_id: int, message: str):
        """Store *message* as the path's text and send it to every other peer.

        The websocket whose ``id()`` equals *sender_id* is skipped, as are
        peers whose client state is no longer CONNECTED. Per-peer send
        failures are collected (not raised) so one broken peer cannot break
        the sender's receive loop. Unknown paths are ignored.
        """
        text = message if message else ""
        with self.lock:
            conn = self.connections.get(path)
            if conn is None:
                return
            conn.text = text
            # Snapshot the recipients so registrations/unregistrations that
            # happen while we await the sends cannot affect this broadcast.
            targets = [
                ws for ws_id, ws in conn.websockets.items() if ws_id != sender_id
            ]

        async def send_one(websocket):
            if websocket.client_state == WebSocketState.CONNECTED:
                # Send this broadcast's own text, not whatever conn.text holds
                # by the time the send runs.
                await websocket.send_text(text)

        # Await the gather so failures are retrieved instead of being left on
        # an orphaned future.
        await asyncio.gather(
            *(send_one(ws) for ws in targets), return_exceptions=True
        )

    async def register(self, path: str, websocket: WebSocket):
        """Add *websocket* to *path*.

        If the path already exists, the new websocket is sent the path's
        current text. A newly created path starts with empty text and nothing
        is sent.
        """
        with self.lock:
            conn = self.connections.get(path)
            if conn is None:
                self.connections[path] = Connection("", {id(websocket): websocket})
                return
            conn.websockets[id(websocket)] = websocket
            text = conn.text
        # Send outside the lock (see the class docstring for why).
        await websocket.send_text(text)

    async def unregister(self, path: str, websocket: WebSocket):
        """Remove *websocket* from *path*; forget the path once it is empty.

        Unknown paths or websockets are ignored, so this is safe to call from
        a ``finally`` block.
        """
        with self.lock:
            conn = self.connections.get(path)
            if conn is None:
                return
            conn.websockets.pop(id(websocket), None)
            if not conn.websockets:
                # Dropping the entry resets the text (a later register creates
                # a fresh Connection with "") and keeps the dict from growing
                # with every path ever visited.
                del self.connections[path]
# ASGI application; files under ./static are served at /static.
app = FastAPI()
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")
# Templates such as index.html are loaded from the current directory.
templates = Jinja2Templates(directory=".")
# One broadcaster shared by both websocket routes below.
ws_broadcaster = WebSocketBroadcaster()
@app.get("/")
async def index():
    """Redirect the site root to the "default" shared page."""
    target = "/default"
    return RedirectResponse(url=target)
@app.get("/{path}")
async def get(request: Request, path: str = ""):
    """Render index.html with the websocket URLs for *path*."""
    message_url = app.url_path_for("ws_message", path=path)
    heartbeat_url = app.url_path_for("ws_heartbeat", path=path)
    page_context = {"ws_message": message_url, "ws_heartbeat": heartbeat_url}
    return templates.TemplateResponse(request, "index.html", context=page_context)
@app.websocket("/ws/input/{path}")
async def ws_message(websocket: WebSocket, path: str):
    """Relay text sent by one client to every other client on *path*.

    The websocket is registered with the shared broadcaster, and every text
    frame it sends is broadcast to the other sockets on the same path. The
    socket is unregistered however the loop ends, not only on a clean
    disconnect. For example, a binary frame makes ``receive_text`` fail,
    which previously left a stale registration behind and inflated the count
    reported by the heartbeat endpoint.
    """
    await websocket.accept()
    try:
        # Inside the try: register may fail while sending the current text,
        # after the socket has already been added.
        await ws_broadcaster.register(path, websocket)
        while True:
            text = await websocket.receive_text()
            await ws_broadcaster.broadcast(path, id(websocket), text)
    except WebSocketDisconnect:
        # Normal client disconnect; only cleanup remains.
        pass
    finally:
        await ws_broadcaster.unregister(path, websocket)
@app.websocket("/ws/heartbeat/{path}")
async def ws_heartbeat(websocket: WebSocket, path: str):
    """Reply to each received message with the number of sockets on *path*."""
    await websocket.accept()
    try:
        while True:
            await websocket.receive_text()
            try:
                count = ws_broadcaster.conn_count(path)
            except KeyError:
                # The heartbeat can ping before any input socket has
                # registered this path; report zero instead of crashing.
                count = 0
            await websocket.send_text(str(count))
    except WebSocketDisconnect:
        return