Skip to content

Commit

Permalink
E2E テストを追加する準備
Browse files Browse the repository at this point in the history
  • Loading branch information
voluntas committed Sep 2, 2024
1 parent 0017cac commit 15da852
Show file tree
Hide file tree
Showing 5 changed files with 517 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ dev-dependencies = [
"auditwheel~=6.0.0",
"ruff>=0.6",
"typing-extensions>=4.12.2",
"pytest>=8.3.2",
"numpy>=2.1.0",
]

# examples でローカルの sora_sdk を利用したい場合は以下を有効にする
Expand Down
10 changes: 10 additions & 0 deletions requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,27 @@
-e file:.
auditwheel==6.0.0
build==1.2.1
exceptiongroup==1.2.2
# via pytest
iniconfig==2.0.0
# via pytest
nanobind==2.1.0
numpy==2.1.0
packaging==24.1
# via auditwheel
# via build
# via pytest
pluggy==1.5.0
# via pytest
pyelftools==0.31
# via auditwheel
pyproject-hooks==1.1.0
# via build
pytest==8.3.2
ruff==0.6.1
setuptools==72.2.0
tomli==2.0.1
# via build
# via pytest
typing-extensions==4.12.2
wheel==0.43.0
257 changes: 257 additions & 0 deletions test/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
import json
import queue
import threading
import time
from threading import Event
from typing import Any, Optional

import numpy

from sora_sdk import (
Sora,
SoraAudioSink,
SoraConnection,
SoraMediaTrack,
SoraSignalingErrorCode,
SoraVideoFrame,
SoraVideoSink,
)


class Sendonly:
def __init__(
self,
signaling_urls: list[str],
channel_id: str,
metadata: Optional[dict[str, Any]] = None,
audio: Optional[bool] = None,
video: Optional[bool] = None,
video_codec_type: Optional[str] = None,
video_bit_rate: Optional[int] = None,
data_channel_signaling: Optional[bool] = None,
openh264_path: Optional[str] = None,
use_hwa: bool = False,
audio_channels: int = 1,
audio_sample_rate: int = 16000,
):
self._signaling_urls: list[str] = signaling_urls
self._channel_id: str = channel_id

self._audio_channels: int = audio_channels
self._audio_sample_rate: int = audio_sample_rate

self._sora: Sora = Sora(openh264=openh264_path, use_hardware_encoder=use_hwa)

self._fake_audio_thread: Optional[threading.Thread] = None
self._fake_video_thread: Optional[threading.Thread] = None

self._audio_source = self._sora.create_audio_source(
self._audio_channels, self._audio_sample_rate
)
self._video_source = self._sora.create_video_source()

self._connection: SoraConnection = self._sora.create_connection(
signaling_urls=signaling_urls,
role="sendonly",
channel_id=channel_id,
metadata=metadata,
audio=audio,
video=video,
video_codec_type=video_codec_type,
video_bit_rate=video_bit_rate,
data_channel_signaling=data_channel_signaling,
audio_source=self._audio_source,
video_source=self._video_source,
)
self._connection_id: Optional[str] = None

self._connected: Event = Event()
self._switched: bool = False
self._closed: Event = Event()
self._default_connection_timeout_s: float = 10.0

self._connection.on_set_offer = self._on_set_offer
self._connection.on_switched = self._on_switched
self._connection.on_notify = self._on_notify
self._connection.on_disconnect = self._on_disconnect

def connect(self, fake_audio=False, fake_video=False) -> None:
self._connection.connect()

if fake_audio:
self._fake_audio_thread = threading.Thread(target=self._fake_audio_loop, daemon=True)
self._fake_audio_thread.start()

if fake_video:
self._fake_video_thread = threading.Thread(target=self._fake_video_loop, daemon=True)
self._fake_video_thread.start()

assert self._connected.wait(
self._default_connection_timeout_s
), "Could not connect to Sora."

def disconnect(self) -> None:
self._connection.disconnect()

def get_stats(self):
raw_stats = self._connection.get_stats()
return json.loads(raw_stats)

@property
def connected(self) -> bool:
return self._connected.is_set()

@property
def switched(self) -> bool:
return self._switched

def _fake_audio_loop(self):
while not self._closed.is_set():
time.sleep(0.02)
self._audio_source.on_data(numpy.zeros((320, 1), dtype=numpy.int16))

def _fake_video_loop(self):
while not self._closed.is_set():
time.sleep(1.0 / 30)
self._video_source.on_captured(numpy.zeros((480, 640, 3), dtype=numpy.uint8))

def _on_set_offer(self, raw_message: str) -> None:
message: dict[str, Any] = json.loads(raw_message)
if message["type"] == "offer":
self._connection_id = message["connection_id"]

def _on_switched(self, raw_message: str) -> None:
message = json.loads(raw_message)
if message["type"] == "switched":
print(f"Switched to DataChannel Signaling: connection_id={self._connection_id}")
self._switched = True

def _on_notify(self, raw_message: str) -> None:
message: dict[str, Any] = json.loads(raw_message)
if (
message["type"] == "notify"
and message["event_type"] == "connection.created"
and message["connection_id"] == self._connection_id
):
print(f"Connected Sora: connection_id={self._connection_id}")
self._connected.set()

def _on_disconnect(self, error_code: SoraSignalingErrorCode, message: str) -> None:
print(f"Disconnected Sora: error_code='{error_code}' message='{message}'")
self._connected.clear()
self._closed.set()

if self._fake_audio_thread is not None:
self._fake_audio_thread.join(timeout=10)

if self._fake_video_thread is not None:
self._fake_video_thread.join(timeout=10)


class Recvonly:
"""Sora からビデオと音声ストリームを受信するためのクラス。"""

def __init__(
self,
signaling_urls: list[str],
channel_id: str,
metadata: Optional[dict[str, Any]] = None,
data_channel_signaling: Optional[bool] = None,
openh264_path: Optional[str] = None,
use_hwa: Optional[bool] = False,
output_frequency: int = 16000,
output_channels: int = 1,
):
self._signaling_urls: list[str] = signaling_urls
self._channel_id: str = channel_id

self._output_frequency: int = output_frequency
self._output_channels: int = output_channels

self._sora: Sora = Sora(openh264=openh264_path, use_hardware_encoder=use_hwa)
self._connection: SoraConnection = self._sora.create_connection(
signaling_urls=signaling_urls,
role="recvonly",
channel_id=channel_id,
metadata=metadata,
data_channel_signaling=data_channel_signaling,
)
self._connection_id: Optional[str] = None

self._connected: Event = Event()
self._switched: bool = False
self._closed: Event = Event()
self._default_connection_timeout_s: float = 10.0

self._audio_sink: Optional[SoraAudioSink] = None
self._video_sink: Optional[SoraVideoSink] = None

self._q_out: queue.Queue = queue.Queue()

self._connection.on_set_offer = self._on_set_offer
self._connection.on_switched = self._on_switched
self._connection.on_notify = self._on_notify
self._connection.on_disconnect = self._on_disconnect
self._connection.on_track = self._on_track

def connect(self) -> None:
self._connection.connect()

assert self._connected.wait(
self._default_connection_timeout_s
), "Could not connect to Sora."

def disconnect(self) -> None:
self._connection.disconnect()

def get_stats(self):
raw_stats = self._connection.get_stats()
return json.loads(raw_stats)

@property
def connected(self) -> bool:
return self._connected.is_set()

@property
def switched(self) -> bool:
return self._switched

@property
def closed(self):
return self._closed.is_set()

def _on_set_offer(self, raw_message: str) -> None:
message: dict[str, Any] = json.loads(raw_message)
if message["type"] == "offer":
self._connection_id = message["connection_id"]

def _on_switched(self, raw_message: str) -> None:
message = json.loads(raw_message)
if message["type"] == "switched":
print(f"Switched to DataChannel Signaling: connection_id={self._connection_id}")
self._switched = True

def _on_notify(self, raw_message: str) -> None:
message: dict[str, Any] = json.loads(raw_message)
if (
message["type"] == "notify"
and message["event_type"] == "connection.created"
and message["connection_id"] == self._connection_id
):
print(f"Connected Sora: connection_id={self._connection_id}")
self._connected.set()

def _on_disconnect(self, error_code: SoraSignalingErrorCode, message: str) -> None:
print(f"Disconnected Sora: error_code='{error_code}' message='{message}'")
self._connected.clear()
self._closed.is_set()

def _on_video_frame(self, frame: SoraVideoFrame) -> None:
self._q_out.put(frame)

def _on_track(self, track: SoraMediaTrack) -> None:
if track.kind == "audio":
self._audio_sink = SoraAudioSink(track, self._output_frequency, self._output_channels)
if track.kind == "video":
self._video_sink = SoraVideoSink(track)
self._video_sink.on_frame = self._on_video_frame
32 changes: 32 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import os

import pytest
from dotenv import load_dotenv


@pytest.fixture
def setup():
# 環境変数読み込み
load_dotenv()

# signaling_url 単体か複数かをランダムで決めてテストする
if (test_signaling_urls := os.environ.get("TEST_SIGNALING_URLS")) is None:
raise ValueError("TEST_SIGNALING_URLS is required.")

# , で区切って ['wss://...', ...] に変換
test_signaling_urls = test_signaling_urls.split(",")

if (test_channel_id_prefix := os.environ.get("TEST_CHANNEL_ID_PREFIX")) is None:
raise ValueError("TEST_CHANNEL_ID_PREFIX is required.")

if (test_secret_key := os.environ.get("TEST_SECRET_KEY")) is None:
raise ValueError("TEST_SECRET_KEY is required.")

return {
"signaling_urls": test_signaling_urls,
"channel_id_prefix": test_channel_id_prefix,
"secret": test_secret_key,
"metadata": {"access_token": test_secret_key},
# openh264_path は str | None でよい
"openh264_path": os.environ.get("OPENH264_PATH"),
}
Loading

0 comments on commit 15da852

Please sign in to comment.