Skip to content

Commit

Permalink
Improve livestream duration handling
Browse files: browse the repository at this point in the history
  • Loading branch information
coletdjnz committed Nov 7, 2024
1 parent b144909 commit 16eb234
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 20 deletions.
1 change: 1 addition & 0 deletions utils/mitmweb_sabrdump.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def response(self, flow: http.HTTPFlow) -> None:

elif part.part_type == UMPPartType.LIVE_METADATA:
f.write(f'Live Metadata: {protobug.loads(part.data, LiveMetadata)}\n')
p.write(f'Live Metadata: {protobug.loads(part.data, LiveMetadata)}\n')

elif part.part_type == UMPPartType.MEDIA or part.part_type == UMPPartType.MEDIA_END:
f.write(f'Media Header Id: {part.data[0]}\n')
Expand Down
45 changes: 30 additions & 15 deletions yt_dlp_plugins/extractor/_ytse/downloader/sabr.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class InitializedFormat:


class SABRStream:
def __init__(self, fd, server_abr_streaming_url: str, video_playback_ustreamer_config: str, po_token_fn: callable, formats: list[SABRFormat], client_info: ClientInfo):
def __init__(self, fd, server_abr_streaming_url: str, video_playback_ustreamer_config: str, po_token_fn: callable, formats: list[SABRFormat], client_info: ClientInfo, live_segment_target_duration_sec: int = 5):
self.server_abr_streaming_url = server_abr_streaming_url
self.video_playback_ustreamer_config = video_playback_ustreamer_config
self.po_token_fn = po_token_fn
Expand All @@ -80,6 +80,7 @@ def __init__(self, fd, server_abr_streaming_url: str, video_playback_ustreamer_c
self.total_duration_ms = None

self.sabr_seeked = False
self.live_segment_target_duration_sec = live_segment_target_duration_sec

def download(self):
video_formats = [format for format in self.requestedFormats if format.format_type == FormatType.VIDEO]
Expand Down Expand Up @@ -151,17 +152,15 @@ def download(self):
min_buffered_duration_ms,
# The next request policy's backoff_time_ms is the minimum amount by which start_time_ms must be advanced
self.client_abr_state.start_time_ms + next_request_backoff_ms,
# self.client_abr_state.start_time_ms + 1000 # guard, always increment by at least 1 second
)

if self.live_metadata and self.client_abr_state.start_time_ms >= self.total_duration_ms:
self.client_abr_state.start_time_ms = self.total_duration_ms
wait_time = next_request_backoff_ms or (self.live_metadata.target_duration_sec * 1000 * 2) # lets wait a couple segments worth of time
wait_time = next_request_backoff_ms + self.live_segment_target_duration_sec * 1000
self.fd.write_debug(f'sleeping {wait_time / 1000} seconds')
time.sleep(wait_time / 1000)

# reset live_metadata in case yt stops sending it...
self.next_request_policy = self.live_metadata = None
self.next_request_policy = None
self.sabr_seeked = False

request_number += 1
Expand Down Expand Up @@ -235,14 +234,16 @@ def process_media_header(self, part: UMPPart):
# Calculate duration of this segment
# For videos, either duration_ms or time_range should be present
# For live streams, calculate segment duration based on live metadata target segment duration


start_ms = media_header.start_ms or (time_range and time_range.get_start_ms()) or 0

duration_ms = (
media_header.duration_ms
or (time_range and time_range.get_duration_ms())
or self.live_metadata and self.live_metadata.target_duration_sec * 1000
or self.live_metadata and self.live_segment_target_duration_sec * 1000
or 0)

start_ms = media_header.start_ms or (time_range and time_range.get_start_ms()) or 0

initialized_format.sequences[sequence_number or 0] = Sequence(
format_id=media_header.format_id,
is_init_segment=is_init_segment,
Expand Down Expand Up @@ -282,11 +283,20 @@ def process_media_header(self, part: UMPPart):

current_buffered_range.end_segment_index = sequence_number

# We need to increment both duration_ms and time_range.duration
current_buffered_range.duration_ms += duration_ms
if current_buffered_range.time_range.timescale != 1000:
raise DownloadError(f'Timescale not 1000 in buffered range time_range: {current_buffered_range}')
current_buffered_range.time_range.duration += duration_ms
if not self.live_metadata:
# We need to increment both duration_ms and time_range.duration
current_buffered_range.duration_ms += duration_ms
current_buffered_range.time_range.duration += duration_ms
else:
# Attempt to keep in sync with livestream, as the segment duration target is not always perfect.
# The server seems to care more about the segment index than the duration.
if current_buffered_range.start_time_ms > start_ms:
raise DownloadError(f'Buffered range start time mismatch: {current_buffered_range.start_time_ms} > {start_ms}')
current_buffered_range.duration_ms = current_buffered_range.time_range.duration = start_ms + duration_ms

if not current_buffered_range.duration_ms - duration_ms == start_ms:
raise DownloadError(f'Buffered range duration mismatch: {current_buffered_range.duration_ms - duration_ms} != {start_ms}')


def process_media(self, part: UMPPart):
header_id = part.data[0]
Expand Down Expand Up @@ -324,7 +334,7 @@ def process_stream_protection_status(self, part: UMPPart):
sps = protobug.loads(part.data, StreamProtectionStatus)
self.write_ump_debug(part, f'Status: {StreamProtectionStatus.Status(sps.status).name} Data: {part.get_b64_str()}')
if sps.status == StreamProtectionStatus.Status.ATTESTATION_REQUIRED:
self.fd.report_error('StreamProtectionStatus: Attestation Required (missing PO Token?)')
raise DownloadError('StreamProtectionStatus: Attestation Required (missing PO Token?)')

def process_sabr_redirect(self, part: UMPPart):
sabr_redirect = protobug.loads(part.data, SabrRedirect)
Expand Down Expand Up @@ -408,7 +418,12 @@ def real_download(self, filename, info_dict):

# assuming we have only selected audio + video, and they are of the same client, for now.

requested_formats = info_dict['requested_formats']
requested_formats = info_dict.get('requested_formats')

if not requested_formats:
requested_formats = [info_dict]
info_dict['filepath'] = filename


formats = []
po_token_fn = lambda: None
Expand Down
6 changes: 3 additions & 3 deletions yt_dlp_plugins/extractor/_ytse/protos/_live_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ class LiveMetadata:
latest_sequence_duration_ms: typing.Optional[protobug.UInt64] = protobug.field(4)

timestamp: typing.Optional[protobug.UInt64] = protobug.field(5)
target_duration_sec: typing.Optional[protobug.UInt32] = protobug.field(10) # duration of each segment, in seconds
unknown_field_10: typing.Optional[protobug.UInt32] = protobug.field(10) # was thought to be target_duration_seconds, but doesn't seem to be

segment_start_duration: typing.Optional[protobug.UInt64] = protobug.field(12) # earliest you can rewind the livestream
segment_start_timescale: typing.Optional[protobug.UInt32] = protobug.field(13)
dvr_start_duration: typing.Optional[protobug.UInt64] = protobug.field(12) # earliest you can rewind the livestream
dvr_start_timescale: typing.Optional[protobug.UInt32] = protobug.field(13)

live_start_duration: typing.Optional[protobug.UInt64] = protobug.field(14) # where SABR seek puts you to start streaming live?
live_start_timescale: typing.Optional[protobug.UInt32] = protobug.field(15)
Expand Down
2 changes: 1 addition & 1 deletion yt_dlp_plugins/extractor/_ytse/protos/_media_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,5 @@ class Compression(enum.IntEnum):
format_id: typing.Optional[FormatId] = protobug.field(13, default=None)
content_length: typing.Optional[protobug.Int64] = protobug.field(14, default=None)
time_range: typing.Optional[TimeRange] = protobug.field(15, default=None)
unknown_field_16: typing.Optional[protobug.Int32] = protobug.field(16, default=None)
timestamp: typing.Optional[protobug.Int32] = protobug.field(16, default=None)

2 changes: 1 addition & 1 deletion yt_dlp_plugins/extractor/ytse.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def get_suitable_downloader(info_dict, params={}, default=NO_DEFAULT, protocol=N
if set(downloaders) == {SABRFD} and SABRFD.can_download(info_copy):
return SABRFD

return get_suitable_downloader_original(info_dict, protocol, params, default)
return get_suitable_downloader_original(info_dict, params, default, protocol, to_stdout)


sys.modules.get('yt_dlp.downloader').get_suitable_downloader = get_suitable_downloader
Expand Down

0 comments on commit 16eb234

Please sign in to comment.