Skip to content

Commit

Permalink
[JSSOCKET]: Unit Test
Browse files Browse the repository at this point in the history
  • Loading branch information
amadolid committed Jan 9, 2024
1 parent 8c8c4c7 commit 4a8ca4c
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 32 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/jaseci-serv-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ jobs:
- name: Verify installation
run: |
jsctl
- name: Install jaseci_socket
run: |
pip3 install jaseci_socket/
jssocket -p 8002 & sleep 2
python -m websockets ws://localhost:8002/ws
- name: Install jaseci_serv and run tests
if: always()
run: |
Expand Down
36 changes: 14 additions & 22 deletions jaseci_serv/jaseci_serv/svc/socket_svc.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,10 @@
from jaseci.utils.utils import logger
from jaseci.jsorc.jsorc import JsOrc
from jaseci.jsorc.jsorc_utils import ManifestType
from jaseci.extens.svc.socket_svc import SocketService as Ss

try:
from hmac import compare_digest
except ImportError:

def compare_digest(a, b):
return a == b


import binascii
from knox.crypto import hash_token
from knox.models import AuthToken
from knox.settings import CONSTANTS
from django.db import connection
from django.utils import timezone
from knox.auth import TokenAuthentication
from websocket import WebSocketApp as wsa

#################################################
Expand All @@ -31,6 +20,16 @@ def compare_digest(a, b):
pre_loaded=True,
)
class SocketService(Ss):
def __init__(
self,
config: dict,
manifest: dict,
manifest_type: ManifestType = ManifestType.DEDICATED,
source: dict = {},
):
self.authenticator = TokenAuthentication()
super().__init__(config, manifest, manifest_type, source)

def client_connect(self, ws: wsa, data: dict):
user = "public"
token = data.pop("token", None)
Expand All @@ -50,15 +49,8 @@ def client_connect(self, ws: wsa, data: dict):

def authenticate(self, token: str):
try:
connection.connect()
for auth_token in AuthToken.objects.filter(
token_key=token[: CONSTANTS.TOKEN_KEY_LENGTH]
):
digest = hash_token(token)
if compare_digest(digest, auth_token.digest) and (
not auth_token.expiry or auth_token.expiry > timezone.now()
):
return auth_token.user
connection.ensure_connection()
return self.authenticator.authenticate_credentials(token.encode())[0]
except Exception:
logger.exception("Error authenticating socket!")
return None
87 changes: 87 additions & 0 deletions jaseci_serv/jaseci_serv/svc/tests/test_socket_svc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from jaseci.utils.utils import TestCaseHelper
from jaseci.jsorc.jsorc import JsOrc
from jaseci_serv.svc.socket_svc import SocketService

from orjson import dumps
from json import loads
from django.urls import reverse
from django.test import TestCase
from django.contrib.auth import get_user_model
from rest_framework.test import APIClient
from websocket import create_connection
from websocket import WebSocket


class SocketServiceTest(TestCaseHelper, TestCase):
def setUp(self):
super().setUp()
self.admin = get_user_model().objects.create_superuser(
"admin@jaseci.com", "password"
)

# since pytest uses in memory database we need to
# get the token first before opening socket service
self.token = (
APIClient()
.post(
reverse("user_api:token"),
{"email": "admin@jaseci.com", "password": "password"},
)
.data["token"]
)

JsOrc.settings("SOCKET_CONFIG").update(
{
"enabled": True,
"url": "ws://localhost:8002/ws",
"ping_url": "http://localhost:8002/healthz",
}
)

def socket_process(self, ws: WebSocket, token: str = None):
data = {}
if token:
data["token"] = token

ws.send(dumps({"type": "client_connect", "data": data}))
event: dict = loads(ws.recv())

self.assertEqual("client_connected", event.get("type"))
data: dict = event.get("data")
target = data.get("target")
self.assertTrue(data)
self.assertTrue(target)
self.assertTrue(data.get("authenticated") == bool(token))

ws.send(
dumps(
{
"type": "notify_client",
"data": {"target": target, "data": {"test": 1}},
}
)
)
event: dict = loads(ws.recv())
self.assertEqual({"test": 1}, event)

ws.send(
dumps(
{
"type": "notify_group",
"data": {"target": target, "data": {"test": 2}},
}
)
)
event: dict = loads(ws.recv())
self.assertEqual({"test": 2}, event)

@JsOrc.inject(services=["socket"])
def test_socket(self, socket: SocketService):
self.assertTrue(socket.is_running())

ws = create_connection("ws://localhost:8002/ws")

self.socket_process(ws)
self.socket_process(ws, self.token)

ws.close()
21 changes: 12 additions & 9 deletions jaseci_socket/jaseci_socket/jssocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class JsSocket:

_servers = dict[str, wssp]()
_clients = dict[str, wssp]()
_groups = dict[str, str]()
_groups = dict[str, set]()

_servers_queue = list()

Expand Down Expand Up @@ -126,14 +126,9 @@ async def server_connect(self, ws: wssp, data: dict):

async def client_connect(self, ws: wssp, data: dict):
ws_id = str(ws.id)
if not getattr(ws, "connected", None):
self._clients[ws_id] = ws
data["target"] = ws_id
await self.server_send({"type": "client_connect", "data": data})
else:
await self.client_send(
ws, {"type": "client_connected", "data": {"target": ws_id}}
)
self._clients[ws_id] = ws
data["target"] = ws_id
await self.server_send({"type": "client_connect", "data": data})
await self.cleanup()

async def client_disconnect(self, ws: wssp, data: dict):
Expand Down Expand Up @@ -162,6 +157,14 @@ async def client_connected(self, ws: wssp, data: dict):
self._groups[user] = set([ws_id])
else:
group.add(ws_id)

old_group = getattr(ws, "group", None)
if old_group:
group: set = self._groups.get(old_group)
if ws_id in group:
logging.info(f"Removing {ws_id} on group {old_group}")
group.remove(ws_id)

ws.group = user
ws.connected = True
await self.client_send(ws, {"type": "client_connected", "data": data})
Expand Down

0 comments on commit 4a8ca4c

Please sign in to comment.