Skip to content

Commit

Permalink
minor refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
coletdjnz committed Nov 2, 2024
1 parent 85ca821 commit 7469e4d
Showing 1 changed file with 140 additions and 128 deletions.
268 changes: 140 additions & 128 deletions yt_dlp_plugins/extractor/_ytse/downloader/sabr.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def download(self):
media_type=media_type,
)

requests = 0
request_number = 0
# add a small buffer to account for small difference in format length
while (not self.total_duration_ms) or self.client_abr_state.start_time_ms + 100 < self.total_duration_ms:
po_token = self.po_token_fn()
Expand Down Expand Up @@ -138,9 +138,11 @@ def download(self):
if not self.total_duration_ms:
self.total_duration_ms = first_format.total_duration_ms

requests += 1
# if requests == 2:
# break
if len(self.header_id_to_format_map):
self.fd.report_warning('Extraneous header IDs left')
self.header_id_to_format_map.clear()

request_number += 1

for initialized_format in self.initialized_formats.values():
initialized_format.requested_format.write_callback(b'', close=True)
Expand All @@ -156,140 +158,150 @@ def parse_ump_response(self, response):
ump = UMPParser(response)
for part in ump.iter_parts():
if part.part_type == UMPPartType.MEDIA_HEADER:
media_header = protobug.loads(part.data, MediaHeader)
self.write_ump_debug(part, f'Parsed header: {media_header} Data: {part.get_b64_str()}')

if not media_header.format_id:
continue
initialized_format = self.initialized_formats.get(get_format_key(media_header.format_id))
if not initialized_format:
self.write_ump_warning(part, f'Initialized format not found for {media_header.format_id}')
continue

sequence_number = media_header.sequence_number
if (sequence_number or 0) in initialized_format.sequences:
self.write_ump_warning(part, f'Sequence {sequence_number} already found, skipping')
continue

is_init_segment = media_header.is_init_segment

initialized_format.sequences[sequence_number or 0] = Sequence(
format_id=media_header.format_id,
is_init_segment=is_init_segment,
duration_ms=media_header.duration_ms,
start_data_range=media_header.start_data_range,
sequence_number=sequence_number,
content_length=media_header.content_length,
start_ms=media_header.start_ms
)

self.header_id_to_format_map[media_header.header_id] = get_format_key(media_header.format_id)

if not is_init_segment:
initialized_format.buffered_range.end_segment_index += 1
initialized_format.buffered_range.duration_ms += media_header.duration_ms

if initialized_format.buffered_range.end_segment_index != sequence_number:
self.write_ump_warning(part, f'End segment index mismatch: {initialized_format.buffered_range.end_segment_index} != {sequence_number}')

initialized_format.current_content_length += media_header.content_length
initialized_format.current_duration_ms += media_header.duration_ms or 0

continue

self.process_media_header(part)
elif part.part_type == UMPPartType.MEDIA:
header_id = part.data[0]
self.write_ump_debug(part, f'Header ID: {header_id}')

format_key = self.header_id_to_format_map.get(header_id)
if not format_key:
self.write_ump_warning(part, f'Initialized Format not found for header ID {header_id}')
continue

initialized_format = self.initialized_formats.get(format_key)

if not initialized_format:
self.write_ump_warning(part, f'Initialized Format not found for header ID {header_id}')
continue

# self.write_ump_debug(part, f'Writing {len(part.data[1:])} bytes to initialized format {initialized_format}')

write_callback = initialized_format.requested_format.write_callback
write_callback(part.data[1:])

self.process_media(part)
elif part.part_type == UMPPartType.MEDIA_END:
header_id = part.data[0]
self.write_ump_debug(part, f' Header ID: {header_id}')
self.header_id_to_format_map.pop(header_id, None)
continue
self.process_media_end(part)
elif part.part_type == UMPPartType.STREAM_PROTECTION_STATUS:
sps = protobug.loads(part.data, StreamProtectionStatus)
self.write_ump_debug(part, f'Status: {StreamProtectionStatus.Status(sps.status).name} Data: {part.get_b64_str()}')
if sps.status == StreamProtectionStatus.Status.ATTESTATION_REQUIRED:
response.close()
self.report_error('StreamProtectionStatus: Attestation Required (missing PO Token?)')
return False

self.process_stream_protection_status(part)
elif part.part_type == UMPPartType.SABR_REDIRECT:
sabr_redirect = protobug.loads(part.data, SabrRedirect)
self.server_abr_streaming_url = sabr_redirect.redirect_url
self.write_ump_debug(part, f'New URL: {self.server_abr_streaming_url}')
if not self.server_abr_streaming_url:
response.close()
self.report_error('SABRRedirect: Invalid redirect URL')
return False

self.process_sabr_redirect(part)
elif part.part_type == UMPPartType.FORMAT_INITIALIZATION_METADATA:
fmt_init_metadata = protobug.loads(part.data, FormatInitializationMetadata)
self.write_ump_debug(part, f'Format Initialization Metadata: {fmt_init_metadata} Data: {part.get_b64_str()}')


initialized_format_key = get_format_key(fmt_init_metadata.format_id)

if initialized_format_key in self.initialized_formats:
self.fd.report_warning('Format already initialized')
continue
# find matching requested format key

matching_requested_format = next((format for format in self.requestedFormats if get_format_key(FormatId(itag=format.itag, last_modified=format.last_modified_at) ) == initialized_format_key), None)

if not matching_requested_format:
self.write_ump_warning(part, f'Format {initialized_format_key} not in requested formats.. Ignoring')
continue

initialized_format = InitializedFormat(
format_id=fmt_init_metadata.format_id,
total_duration_ms=fmt_init_metadata.duration_ms,
mime_type=fmt_init_metadata.mime_type,
buffered_range=BufferedRange(
format_id=fmt_init_metadata.format_id,
start_time_ms=0,
duration_ms=0,
start_segment_index=0,
end_segment_index=0,
),
video_id=fmt_init_metadata.video_id,
requested_format=matching_requested_format
)

self.initialized_formats[get_format_key(fmt_init_metadata.format_id)] = initialized_format

self.write_ump_debug(part, f'Initialized Format: {initialized_format}')
continue

self.process_format_initialization_metadata(part)
elif part.part_type == UMPPartType.NEXT_REQUEST_POLICY:
self.write_ump_debug(part, f'Next Request Policy Data: {part.get_b64_str()}')
next_request_policy = protobug.loads(part.data, NextRequestPolicy)

playback_cookie = next_request_policy.playback_cookie
if playback_cookie:
self.playback_cookie = playback_cookie
self.write_ump_debug(part, f'Playback Cookie: {playback_cookie}')
continue
self.process_next_request_policy(part)
else:
self.write_ump_warning(part, f'Unhandled part type: {part.part_type.name}:{part.part_id} Data: {part.get_b64_str()}')
continue

def process_media_header(self, part: UMPPart):
    """Parse a MEDIA_HEADER part and register the segment it announces.

    The header is decoded into a ``MediaHeader``. When it refers to an
    already-initialized format and a sequence that has not been seen yet,
    a ``Sequence`` entry is stored on that format, the header ID is mapped
    to the format key (so later MEDIA parts can be routed), and the
    format's buffered range and running totals are advanced.

    Headers without a format ID, for an unknown format, or for an already
    recorded sequence are skipped with a warning.
    """
    media_header = protobug.loads(part.data, MediaHeader)
    self.write_ump_debug(part, f'Parsed header: {media_header} Data: {part.get_b64_str()}')
    if not media_header.format_id:
        self.write_ump_warning(part, 'Format ID not found')
        return

    format_key = get_format_key(media_header.format_id)
    initialized_format = self.initialized_formats.get(format_key)
    if not initialized_format:
        self.write_ump_warning(part, f'Initialized format not found for {media_header.format_id}')
        return

    sequence_number = media_header.sequence_number
    # A missing sequence number is stored under key 0.
    sequence_key = sequence_number or 0
    if sequence_key in initialized_format.sequences:
        self.write_ump_warning(part, f'Sequence {sequence_number} already found, skipping')
        return

    is_init_segment = media_header.is_init_segment
    # Normalise once: a header without a duration must not raise TypeError
    # halfway through the bookkeeping below.
    duration_ms = media_header.duration_ms or 0

    initialized_format.sequences[sequence_key] = Sequence(
        format_id=media_header.format_id,
        is_init_segment=is_init_segment,
        duration_ms=media_header.duration_ms,
        start_data_range=media_header.start_data_range,
        sequence_number=sequence_number,
        content_length=media_header.content_length,
        start_ms=media_header.start_ms,
    )

    self.header_id_to_format_map[media_header.header_id] = format_key

    if not is_init_segment:
        buffered_range = initialized_format.buffered_range
        buffered_range.end_segment_index += 1
        buffered_range.duration_ms += duration_ms

        if buffered_range.end_segment_index != sequence_number:
            self.write_ump_warning(part, f'End segment index mismatch: {buffered_range.end_segment_index} != {sequence_number}')

    initialized_format.current_content_length += media_header.content_length
    initialized_format.current_duration_ms += duration_ms

def process_media(self, part: UMPPart):
    """Forward the payload of a MEDIA part to its format's write callback.

    The first byte of the part data is the header ID; the remaining bytes
    are handed to ``requested_format.write_callback`` of the initialized
    format that the header ID maps to. If the header ID is not mapped, or
    the mapped format is not initialized, a warning is emitted and nothing
    is written.
    """
    header_id = part.data[0]
    self.write_ump_debug(part, f'Header ID: {header_id}')

    # Resolve header ID -> format key -> initialized format in one chain;
    # both failure cases report the same warning.
    mapped_key = self.header_id_to_format_map.get(header_id)
    target_format = self.initialized_formats.get(mapped_key) if mapped_key else None
    if not target_format:
        self.write_ump_warning(part, f'Initialized Format not found for header ID {header_id}')
        return

    # todo: improve write callback
    target_format.requested_format.write_callback(part.data[1:])

def process_media_end(self, part: UMPPart):
    """Drop the header ID named by a MEDIA_END part from the routing map.

    The header ID is the first byte of the part data. Removing an ID that
    is not in the map is a no-op.
    """
    finished_header_id = part.data[0]
    self.write_ump_debug(part, f' Header ID: {finished_header_id}')
    self.header_id_to_format_map.pop(finished_header_id, None)

def process_stream_protection_status(self, part: UMPPart):
    """Decode a STREAM_PROTECTION_STATUS part and log its status name.

    When the status is ``ATTESTATION_REQUIRED`` an error is reported
    through the file downloader; any other status is only logged.
    """
    protection = protobug.loads(part.data, StreamProtectionStatus)
    status_name = StreamProtectionStatus.Status(protection.status).name
    self.write_ump_debug(part, f'Status: {status_name} Data: {part.get_b64_str()}')

    if protection.status != StreamProtectionStatus.Status.ATTESTATION_REQUIRED:
        return
    self.fd.report_error('StreamProtectionStatus: Attestation Required (missing PO Token?)')

def process_sabr_redirect(self, part: UMPPart):
    """Adopt the redirect URL carried by a SABR_REDIRECT part.

    The streaming URL on this object is replaced with the decoded
    ``redirect_url`` and logged. An error is reported through the file
    downloader if the resulting URL is falsy.
    """
    redirect = protobug.loads(part.data, SabrRedirect)
    self.server_abr_streaming_url = redirect.redirect_url
    self.write_ump_debug(part, f'New URL: {self.server_abr_streaming_url}')

    # todo: validate redirect URL
    if self.server_abr_streaming_url:
        return
    self.fd.report_error('SABRRedirect: Invalid redirect URL')

def process_format_initialization_metadata(self, part: UMPPart):
    """Create an ``InitializedFormat`` from a FORMAT_INITIALIZATION_METADATA part.

    The part is decoded into ``FormatInitializationMetadata`` and keyed by
    its format ID. The format is registered in ``self.initialized_formats``
    with an empty buffered range, provided that it has not been
    initialized before and it matches one of ``self.requestedFormats``
    (compared by itag and last-modified). Otherwise a warning is emitted
    and the part is ignored.
    """
    fmt_init_metadata = protobug.loads(part.data, FormatInitializationMetadata)
    self.write_ump_debug(part, f'Format Initialization Metadata: {fmt_init_metadata} Data: {part.get_b64_str()}')

    initialized_format_key = get_format_key(fmt_init_metadata.format_id)

    if initialized_format_key in self.initialized_formats:
        self.fd.report_warning('Format already initialized')
        return

    # Find the requested format whose key matches the announced format.
    matching_requested_format = next(
        (
            requested_format
            for requested_format in self.requestedFormats
            if get_format_key(FormatId(
                itag=requested_format.itag,
                last_modified=requested_format.last_modified_at,
            )) == initialized_format_key
        ),
        None,
    )

    if not matching_requested_format:
        self.write_ump_warning(part, f'Format {initialized_format_key} not in requested formats.. Ignoring')
        return

    initialized_format = InitializedFormat(
        format_id=fmt_init_metadata.format_id,
        total_duration_ms=fmt_init_metadata.duration_ms,
        mime_type=fmt_init_metadata.mime_type,
        # Nothing has been buffered yet for a freshly initialized format.
        buffered_range=BufferedRange(
            format_id=fmt_init_metadata.format_id,
            start_time_ms=0,
            duration_ms=0,
            start_segment_index=0,
            end_segment_index=0,
        ),
        video_id=fmt_init_metadata.video_id,
        requested_format=matching_requested_format,
    )

    self.initialized_formats[initialized_format_key] = initialized_format

    self.write_ump_debug(part, f'Initialized Format: {initialized_format}')


def process_next_request_policy(self, part: UMPPart):
    """Record the playback cookie from a NEXT_REQUEST_POLICY part.

    The raw part is logged first, then decoded. A non-empty
    ``playback_cookie`` replaces ``self.playback_cookie`` and is logged;
    an empty one leaves the stored cookie untouched.
    """
    self.write_ump_debug(part, f'Next Request Policy Data: {part.get_b64_str()}')
    policy = protobug.loads(part.data, NextRequestPolicy)

    cookie = policy.playback_cookie
    if not cookie:
        return

    self.playback_cookie = cookie
    self.write_ump_debug(part, f'Playback Cookie: {cookie}')


class SABRFD(FileDownloader):

Expand Down

0 comments on commit 7469e4d

Please sign in to comment.