Skip to content

Commit

Permalink
Rework duration / buffered ranges handling
Browse files Browse the repository at this point in the history
  • Loading branch information
coletdjnz committed Nov 6, 2024
1 parent 6aab1b4 commit b144909
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 69 deletions.
159 changes: 91 additions & 68 deletions yt_dlp_plugins/extractor/_ytse/downloader/sabr.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import base64
import dataclasses
import enum
import math
import time
from typing import List
import protobug
from yt_dlp import traverse_obj, int_or_none, DownloadError
from yt_dlp.downloader import FileDownloader
Expand All @@ -11,6 +13,7 @@
from yt_dlp_plugins.extractor._ytse.protos._buffered_range import BufferedRange
from yt_dlp_plugins.extractor._ytse.protos._format_id import FormatId
from yt_dlp_plugins.extractor._ytse.protos._streamer_context import StreamerContext, ClientInfo
from yt_dlp_plugins.extractor._ytse.protos._time_range import TimeRange
from yt_dlp_plugins.extractor._ytse.ump import UMPParser, UMPPart, UMPPartType


Expand Down Expand Up @@ -48,14 +51,12 @@ class Sequence:
class InitializedFormat:
format_id: FormatId
video_id: str
buffered_range: BufferedRange
requested_format: SABRFormat
total_duration_ms: int = 0
duration_ms: int = 0
end_time_ms: int = 0
current_content_length: int = 0
current_duration_ms: int = 0
mime_type: str = None
sequences: dict[int, Sequence] = dataclasses.field(default_factory=dict)
buffered_ranges: List[BufferedRange] = dataclasses.field(default_factory=list)


class SABRStream:
Expand All @@ -78,6 +79,8 @@ def __init__(self, fd, server_abr_streaming_url: str, video_playback_ustreamer_c

self.total_duration_ms = None

self.sabr_seeked = False

def download(self):
video_formats = [format for format in self.requestedFormats if format.format_type == FormatType.VIDEO]
audio_formats = [format for format in self.requestedFormats if format.format_type == FormatType.AUDIO]
Expand All @@ -92,18 +95,12 @@ def download(self):

# initialize client abr state
self.client_abr_state = ClientAbrState(
last_manual_direction=0,
time_since_last_manual_format_selection_ms=0,
visibility=0,
start_time_ms=0, # todo
#quality="720",

media_type=media_type,
start_time_ms=0, # current duration
media_type=media_type,
)

request_number = 0
# add a small buffer to account for small difference in format length
while self.live_metadata or not self.total_duration_ms or self.client_abr_state.start_time_ms + 100 < self.total_duration_ms:
while self.live_metadata or not self.total_duration_ms or self.client_abr_state.start_time_ms < self.total_duration_ms:
po_token = self.po_token_fn()
vpabr = VideoPlaybackAbrRequest(
client_abr_state=self.client_abr_state,
Expand All @@ -118,7 +115,10 @@ def download(self):
playback_cookie=self.next_request_policy and protobug.dumps(self.next_request_policy.playback_cookie),
client_info=self.client_info
),
buffered_ranges=[initialized_format.buffered_range for initialized_format in self.initialized_formats.values()],
buffered_ranges=[
buffered_range for initialized_format in self.initialized_formats.values()
for buffered_range in initialized_format.buffered_ranges
],
)
payload = protobug.dumps(vpabr)

Expand All @@ -130,38 +130,49 @@ def download(self):
method='POST',
data=payload,
query={'rn': request_number},
headers={'content-type': 'application/x-protobuf'}
)
)

self.parse_ump_response(response)

ivs = list(self.initialized_formats.values())
if len(self.header_id_to_format_map):
self.fd.report_warning('Extraneous header IDs left')
self.header_id_to_format_map.clear()

current_buffered_ranges = [initialized_format.buffered_ranges[-1] for initialized_format in self.initialized_formats.values() if initialized_format.buffered_ranges]

# choose format that is the most behind
min_buffered_duration_ms = min(current_buffered_ranges, key=lambda x: x.duration_ms).duration_ms if current_buffered_ranges else 0

format_min_duration = ivs and min(self.initialized_formats.values(), key=lambda x: x.current_duration_ms)
next_request_backoff_ms = (self.next_request_policy and self.next_request_policy.backoff_time_ms) or 0

# next request policy backoff_time_ms is the minimum to increment start_time_ms by
self.client_abr_state.start_time_ms = max(
(format_min_duration and format_min_duration.current_duration_ms) or 0,
self.client_abr_state.start_time_ms + ((self.next_request_policy and self.next_request_policy.backoff_time_ms) or 0),
self.client_abr_state.start_time_ms + 1000
min_buffered_duration_ms,
# next request policy backoff_time_ms is the minimum to increment start_time_ms by
self.client_abr_state.start_time_ms + next_request_backoff_ms,
# self.client_abr_state.start_time_ms + 1000 # guard, always increment by at least 1 second
)

if len(self.header_id_to_format_map):
self.fd.report_warning('Extraneous header IDs left')
self.header_id_to_format_map.clear()

if self.live_metadata and self.client_abr_state.start_time_ms >= self.total_duration_ms:
self.client_abr_state.start_time_ms = self.total_duration_ms
wait_time = ((self.next_request_policy and self.next_request_policy.backoff_time_ms) or 0) or (self.live_metadata.target_duration_sec * 1000 * 2)
wait_time = next_request_backoff_ms or (self.live_metadata.target_duration_sec * 1000 * 2) # lets wait a couple segments worth of time
self.fd.write_debug(f'sleeping {wait_time / 1000} seconds')
time.sleep(wait_time / 1000)

self.next_request_policy = None
# reset live_metadata in case yt stops sending it...
self.next_request_policy = self.live_metadata = None
self.sabr_seeked = False

request_number += 1

self.fd.write_debug(f'{self.client_abr_state.start_time_ms}/{self.total_duration_ms}')
self.fd.write_debug(f'Progress: {self.client_abr_state.start_time_ms}/{self.total_duration_ms}')

# The response should automatically close when all data is read, but just in case...
if not response.closed:
response.close()

# todo: improve callback
for initialized_format in self.initialized_formats.values():
initialized_format.requested_format.write_callback(b'', close=True)

Expand Down Expand Up @@ -219,15 +230,18 @@ def process_media_header(self, part: UMPPart):

is_init_segment = media_header.is_init_segment

# duration_ms may not be provided - instead time_range is given (ios)
duration_ms = (
media_header.duration_ms
or (media_header.time_range and int((media_header.time_range.duration / media_header.time_range.timescale) * 1000))
or 0)
if self.live_metadata and not duration_ms:
duration_ms = self.live_metadata.target_duration_sec * 1000
time_range = media_header.time_range

# Calculate duration of this segment
# For videos, either duration_ms or time_range should be present
# For live streams, calculate segment duration based on live metadata target segment duration
duration_ms = (
media_header.duration_ms
or (time_range and time_range.get_duration_ms())
or self.live_metadata and self.live_metadata.target_duration_sec * 1000
or 0)

start_ms = media_header.start_ms or (time_range and time_range.get_start_ms()) or 0

initialized_format.sequences[sequence_number or 0] = Sequence(
format_id=media_header.format_id,
Expand All @@ -236,31 +250,51 @@ def process_media_header(self, part: UMPPart):
start_data_range=media_header.start_data_range,
sequence_number=sequence_number,
content_length=media_header.content_length,
start_ms=media_header.start_ms,
start_ms=start_ms,
)

self.header_id_to_format_map[media_header.header_id] = get_format_key(media_header.format_id)

if not is_init_segment and sequence_number != 0:
if not is_init_segment:
current_buffered_range = initialized_format.buffered_ranges[-1] if initialized_format.buffered_ranges else None

# todo: if we sabr seek, then we get two segments in same request, we end up creating two buffered ranges.
# Perhaps we should have sabr_seeked as part of initialized_format?
if not current_buffered_range or self.sabr_seeked:
initialized_format.buffered_ranges.append(BufferedRange(
format_id=media_header.format_id,
start_time_ms=start_ms,
duration_ms=duration_ms,
start_segment_index=sequence_number,
end_segment_index=sequence_number,
time_range=TimeRange(
start=start_ms,
duration=duration_ms,
timescale=1000 # ms
)
))
self.write_ump_debug(part, f'Created new buffered range for {media_header.format_id} (sabr seeked={self.sabr_seeked}): {initialized_format.buffered_ranges[-1]}')
return

if initialized_format.buffered_range.end_segment_index + 1 != sequence_number:
self.write_ump_warning(part, f'End segment index mismatch: {initialized_format.buffered_range.end_segment_index + 1} != {sequence_number}. Jumping to {sequence_number}.')
end_segment_index = current_buffered_range.end_segment_index or 0
if end_segment_index != 0 and end_segment_index + 1 != sequence_number:
raise DownloadError(part, f'End segment index mismatch: {end_segment_index + 1} != {sequence_number}. Buffered Range: {current_buffered_range}')

initialized_format.buffered_range.end_segment_index = sequence_number
initialized_format.buffered_range.duration_ms += duration_ms
current_buffered_range.end_segment_index = sequence_number

content_length = media_header.content_length
if content_length:
initialized_format.current_content_length = content_length
initialized_format.current_duration_ms += duration_ms or 0
# We need to increment both duration_ms and time_range.duration
current_buffered_range.duration_ms += duration_ms
if current_buffered_range.time_range.timescale != 1000:
raise DownloadError(f'Timescale not 1000 in buffered range time_range: {current_buffered_range}')
current_buffered_range.time_range.duration += duration_ms

def process_media(self, part: UMPPart):
header_id = part.data[0]
self.write_ump_debug(part, f'Header ID: {header_id}')
# self.write_ump_debug(part, f'Header ID: {header_id}')

format_key = self.header_id_to_format_map.get(header_id)
if not format_key:
self.write_ump_warning(part, f'Initialized Format not found for header ID {header_id}')
#self.write_ump_warning(part, f'Initialized Format not found for header ID {header_id}')
return

initialized_format = self.initialized_formats.get(format_key)
Expand Down Expand Up @@ -312,7 +346,7 @@ def process_format_initialization_metadata(self, part: UMPPart):
initialized_format_key = get_format_key(fmt_init_metadata.format_id)

if initialized_format_key in self.initialized_formats:
self.fd.report_warning('Format already initialized')
self.write_ump_debug(part, 'Format already initialized')
return
# find matching requested format key

Expand All @@ -322,28 +356,21 @@ def process_format_initialization_metadata(self, part: UMPPart):
self.write_ump_warning(part, f'Format {initialized_format_key} not in requested formats.. Ignoring')
return

duration_ms = fmt_init_metadata.duration and ((fmt_init_metadata.duration // fmt_init_metadata.duration_timescale) * 1000)

initialized_format = InitializedFormat(
format_id=fmt_init_metadata.format_id,
total_duration_ms=int((fmt_init_metadata.duration / fmt_init_metadata.duration_timescale) * 1000),
duration_ms=duration_ms,
end_time_ms=fmt_init_metadata.end_time_ms,
mime_type=fmt_init_metadata.mime_type,
buffered_range=BufferedRange(
format_id=fmt_init_metadata.format_id,
start_time_ms=0,
duration_ms=0,
start_segment_index=0,
end_segment_index=0,
# todo: also use time range for ios? perhaps will avoid duplicate segments..
),
video_id=fmt_init_metadata.video_id,
requested_format=matching_requested_format,
)
# fmt total_duration_ms may be inaccurate - end_time_ms is usually accurate though? (ios)
self.total_duration_ms = max(self.total_duration_ms or 0, fmt_init_metadata.end_time_ms or 0)
self.total_duration_ms = max(self.total_duration_ms or 0, fmt_init_metadata.end_time_ms or 0, duration_ms or 0)

self.initialized_formats[get_format_key(fmt_init_metadata.format_id)] = initialized_format

self.write_ump_debug(part, f'Initialized Format: {initialized_format}')\
self.write_ump_debug(part, f'Initialized Format: {initialized_format}')


def process_next_request_policy(self, part: UMPPart):
Expand All @@ -352,20 +379,16 @@ def process_next_request_policy(self, part: UMPPart):

def process_sabr_seek(self, part: UMPPart):
    """Handle a SABR seek part by moving the client playback position.

    Decodes ``part.data`` as a ``SabrSeek`` message, converts its ``start``
    (expressed in ``timescale`` units) into whole milliseconds, and stores
    that value as the ``start_time_ms`` of the client ABR state.

    Args:
        part: The UMP part whose payload is a serialized ``SabrSeek``.
    """
    sabr_seek = protobug.loads(part.data, SabrSeek)
    # math.ceil turns the float conversion into an integer millisecond value.
    seek_to = math.ceil((sabr_seek.start / sabr_seek.timescale) * 1000)
    self.write_ump_debug(part, f'Sabr Seek: {sabr_seek} Data: {part.get_b64_str()}')
    self.write_ump_debug(part, f'Seeking to {seek_to}ms')

    self.client_abr_state.start_time_ms = seek_to
    # While this flag is set, media-header handling opens a new buffered
    # range instead of extending the latest one; the download loop clears
    # the flag again after each request.
    self.sabr_seeked = True

def process_sabr_error(self, part: UMPPart):
    """Abort the download when the server sends a SABR error part.

    Args:
        part: The UMP part whose payload is a serialized ``SabrError``.

    Raises:
        DownloadError: Always; the message carries the decoded error and
            the base64 form of the raw part for debugging.
    """
    sabr_error = protobug.loads(part.data, SabrError)
    raise DownloadError(f'SABR Error: {sabr_error} Data: {part.get_b64_str()}')


class SABRFD(FileDownloader):
Expand Down Expand Up @@ -412,7 +435,7 @@ def callback(data, close=False):
self.write_debug(f'Closing {filename}')
stream.close()
return True
self.write_debug(f'Writing {len(data)} bytes to {filename}')
#self.write_debug(f'Writing {len(data)} bytes to {filename}')
try:
stream.write(data)
except OSError as err:
Expand Down
15 changes: 14 additions & 1 deletion yt_dlp_plugins/extractor/_ytse/protos/_time_range.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
import typing
import protobug
import math


@protobug.message
class TimeRange:
    """Protobuf message describing a span of media time.

    ``start`` and ``duration`` are counted in ``timescale`` units per second
    (a ``timescale`` of 1000 therefore means milliseconds). All three fields
    are optional and default to ``None``.
    """

    start: typing.Optional[protobug.Int64] = protobug.field(1, default=None)
    duration: typing.Optional[protobug.Int64] = protobug.field(2, default=None)
    timescale: typing.Optional[protobug.Int32] = protobug.field(3, default=None)

    def get_duration_ms(self):
        """Return the duration in milliseconds, rounded up.

        Returns:
            The integer millisecond duration, or ``None`` when ``duration``
            or ``timescale`` is unset or zero.
        """
        if not self.duration or not self.timescale:
            return None

        # Exact integer ceiling, -(-a // b) == ceil(a / b): avoids the float
        # round trip, which loses precision on large Int64 values and can
        # land just above a whole number so that ceil() overshoots by 1 ms.
        return -((-self.duration * 1000) // self.timescale)

    def get_start_ms(self):
        """Return the start position in milliseconds, rounded up.

        Returns:
            The integer millisecond start, or ``None`` when ``start`` or
            ``timescale`` is unset or zero.
        """
        if not self.start or not self.timescale:
            return None

        # Same exact integer ceiling as get_duration_ms.
        return -((-self.start * 1000) // self.timescale)

0 comments on commit b144909

Please sign in to comment.