Skip to content

Commit

Permalink
Add typing (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
daveisfera authored Jan 21, 2024
1 parent 3faab2c commit 3e56d13
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 34 deletions.
3 changes: 2 additions & 1 deletion pyrtmp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any

from bitstring import BitStream
from bitstring.bits import Bits
from bitstring.utils import tokenparser


Expand All @@ -23,7 +24,7 @@ def __init__(self, reader: StreamReader) -> None:
self.total_bytes = 0
super().__init__()

async def read(self, fmt):
async def read(self, fmt) -> int | float | str | Bits | bool | bytes | None:
_, token = tokenparser(fmt)
assert len(token) == 1
name, length, _ = token[0]
Expand Down
14 changes: 8 additions & 6 deletions pyrtmp/messages/handshake.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from bitstring import BitArray, BitStream

from pyrtmp import BitStreamReader
Expand All @@ -9,11 +11,11 @@ def __init__(self, protocol_version: int) -> None:
super().__init__()

@classmethod
async def from_stream(cls, stream: BitStreamReader):
async def from_stream(cls, stream: BitStreamReader) -> C0:
protocol_version = await stream.read("uint:8")
return cls(protocol_version=protocol_version)

def to_bytes(self):
def to_bytes(self) -> bytes:
stream = BitStream()
stream.append(BitArray(uint=self.protocol_version, length=8))
return stream.bytes
Expand All @@ -27,13 +29,13 @@ def __init__(self, time: int, zero: int, random: bytes) -> None:
super().__init__()

@classmethod
async def from_stream(cls, stream: BitStreamReader):
async def from_stream(cls, stream: BitStreamReader) -> C1:
time = await stream.read("uint:32")
zero = await stream.read("uint:32")
rand = await stream.read("bytes:1528")
return cls(time=time, zero=zero, random=rand)

def to_bytes(self):
def to_bytes(self) -> bytes:
stream = BitStream()
stream.append(BitArray(uint=self.time, length=32))
stream.append(BitArray(uint=self.zero, length=32))
Expand All @@ -49,13 +51,13 @@ def __init__(self, time1: int, time2: int, random: bytes) -> None:
super().__init__()

@classmethod
async def from_stream(cls, stream: BitStreamReader):
async def from_stream(cls, stream: BitStreamReader) -> C2:
time1 = await stream.read("uint:32")
time2 = await stream.read("uint:32")
rand = await stream.read("bytes:1528")
return cls(time1=time1, time2=time2, random=rand)

def to_bytes(self):
def to_bytes(self) -> bytes:
stream = BitStream()
stream.append(BitArray(uint=self.time1, length=32))
stream.append(BitArray(uint=self.time2, length=32))
Expand Down
40 changes: 20 additions & 20 deletions pyrtmp/rtmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ async def client_callback(self, reader: StreamReader, writer: StreamWriter) -> N

writer.close()

async def on_handshake(self, session) -> None:
async def on_handshake(self, session: SessionManager) -> None:
await session.handshake()

async def on_nc_connect(self, session, message) -> None:
async def on_nc_connect(self, session: SessionManager, message: NCConnect) -> None:
session.write_chunk_to_stream(WindowAcknowledgementSize(ack_window_size=5000000))
session.write_chunk_to_stream(SetPeerBandwidth(ack_window_size=5000000, limit_type=2))
session.write_chunk_to_stream(StreamBegin(stream_id=0))
Expand All @@ -128,37 +128,37 @@ async def on_nc_connect(self, session, message) -> None:
session.write_chunk_to_stream(message.create_response())
await session.drain()

async def on_window_acknowledgement_size(self, session, message) -> None:
async def on_window_acknowledgement_size(self, session: SessionManager, message: WindowAcknowledgementSize) -> None:
pass

async def on_nc_create_stream(self, session, message) -> None:
async def on_nc_create_stream(self, session: SessionManager, message: NCCreateStream) -> None:
session.write_chunk_to_stream(message.create_response())
await session.drain()

async def on_ns_publish(self, session, message) -> None:
async def on_ns_publish(self, session: SessionManager, message: NSPublish) -> None:
session.write_chunk_to_stream(StreamBegin(stream_id=1))
session.write_chunk_to_stream(message.create_response())
await session.drain()

async def on_metadata(self, session, message) -> None:
async def on_metadata(self, session: SessionManager, message: MetaDataMessage) -> None:
pass

async def on_set_chunk_size(self, session, message) -> None:
async def on_set_chunk_size(self, session: SessionManager, message: SetChunkSize) -> None:
session.reader_chunk_size = message.chunk_size

async def on_video_message(self, session, message) -> None:
async def on_video_message(self, session: SessionManager, message: VideoMessage) -> None:
pass

async def on_audio_message(self, session, message) -> None:
async def on_audio_message(self, session: SessionManager, message: AudioMessage) -> None:
pass

async def on_ns_close_stream(self, session, message) -> None:
async def on_ns_close_stream(self, session: SessionManager, message: NSCloseStream) -> None:
pass

async def on_ns_delete_stream(self, session, message) -> None:
async def on_ns_delete_stream(self, session: SessionManager, message: NSDeleteStream) -> None:
pass

async def on_unknown_message(self, session, message) -> None:
async def on_unknown_message(self, session: SessionManager, message: Chunk) -> None:
logger.warning(f"Unknown message {str(message)}")

async def on_stream_closed(self, session: SessionManager, exception: StreamClosedException) -> None:
Expand All @@ -180,42 +180,42 @@ def __init__(self, controller: BaseRTMPController) -> None:


class SimpleRTMPServer:
def __init__(self):
def __init__(self) -> None:
self.server = None
self.on_start = None
self.on_stop = None

def _signal_on_start(self):
def _signal_on_start(self) -> None:
if self.on_start:
self.on_start()

def _signal_on_stop(self):
def _signal_on_stop(self) -> None:
if self.on_stop:
self.on_stop()

async def create(self, host: str, port: int):
async def create(self, host: str, port: int) -> None:
loop = asyncio.get_event_loop()
self.server = await loop.create_server(
lambda: RTMPProtocol(controller=SimpleRTMPController()),
host=host,
port=port,
)

async def start(self):
async def start(self) -> None:
addr = self.server.sockets[0].getsockname()
await self.server.start_serving()
self._signal_on_start()
logger.info(f"Serving on {addr}")

async def wait_closed(self):
async def wait_closed(self) -> None:
await self.server.wait_closed()

async def stop(self):
async def stop(self) -> None:
self.server.close()
self._signal_on_stop()


async def main():
async def main() -> None:
server = SimpleRTMPServer()
await server.create(host="0.0.0.0", port=1935)
await server.start()
Expand Down
17 changes: 10 additions & 7 deletions pyrtmp/session_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from asyncio import StreamReader, StreamWriter
from collections.abc import Generator

from bitstring import BitStream
Expand All @@ -10,7 +11,9 @@


class SessionManager:
def __init__(self, reader, writer, reader_chunk_size=128, writer_chunk_size=128) -> None:
def __init__(
self, reader: StreamReader, writer: StreamWriter, reader_chunk_size: int = 128, writer_chunk_size: int = 128
) -> None:
self.reader = reader
self.writer = writer
self.reader_chunk_size = reader_chunk_size
Expand All @@ -22,22 +25,22 @@ def __init__(self, reader, writer, reader_chunk_size=128, writer_chunk_size=128)
super().__init__()

@property
def total_read_bytes(self):
def total_read_bytes(self) -> int:
return self.fifo_reader.total_bytes

@property
def peername(self):
def peername(self) -> str:
a, b = self.writer.get_extra_info("peername")
return f"{a}:{b}"

def set_latest_chunk(self, chunk: RawChunk):
def set_latest_chunk(self, chunk: RawChunk) -> None:
self.latest_chunks[str(chunk.chunk_id)] = chunk
self.latest_chunks["latest"] = chunk

def get_previous_chunk(self, chunk_id: int) -> RawChunk:
return self.latest_chunks[str(chunk_id)]

async def handshake(self):
async def handshake(self) -> None:
# read c0c1
c0 = await C0.from_stream(self.fifo_reader)
c1 = await C1.from_stream(self.fifo_reader)
Expand Down Expand Up @@ -179,12 +182,12 @@ async def read_raw_chunk(self) -> RawChunk:
# return
return instance

def write_chunk_to_stream(self, chunk: Chunk):
def write_chunk_to_stream(self, chunk: Chunk) -> None:
chunks = chunk.to_raw_chunks(self.writer_chunk_size, self.previous_chunk_for_writing)
for chunk in chunks:
self.writer.write(chunk.to_bytes())
self.previous_chunk_for_writing = chunk

async def drain(self):
async def drain(self) -> None:
self.previous_chunk_for_writing = None
await self.writer.drain()

0 comments on commit 3e56d13

Please sign in to comment.