Skip to content

Commit

Permalink
[utils] traverse_obj: Fix more bugs
Browse files Browse the repository at this point in the history
and cleanup uses of `default=[]`

Continued from b1bde57
  • Loading branch information
Grub4K authored and pukkandan committed Feb 10, 2023
1 parent c0cd13f commit 6839ae1
Show file tree
Hide file tree
Showing 10 changed files with 84 additions and 67 deletions.
75 changes: 48 additions & 27 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2000,7 +2000,7 @@ def test_traverse_obj(self):

# Test Ellipsis behavior
self.assertCountEqual(traverse_obj(_TEST_DATA, ...),
(item for item in _TEST_DATA.values() if item not in (None, [], {})),
(item for item in _TEST_DATA.values() if item not in (None, {})),
msg='`...` should give all non discarded values')
self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, ...)), _TEST_DATA['urls'][0].values(),
msg='`...` selection for dicts should select all values')
Expand Down Expand Up @@ -2095,7 +2095,7 @@ def test_traverse_obj(self):
msg='remove empty values when nested dict key fails')
self.assertEqual(traverse_obj(None, {0: 'fail'}), {},
msg='default to dict if pruned')
self.assertEqual(traverse_obj(None, {0: 'fail'}, default=...), {},
self.assertEqual(traverse_obj(None, {0: 'fail'}, default=...), {0: ...},
msg='default to dict if pruned and default is given')
self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}, default=...), {0: {0: ...}},
msg='use nested `default` when nested dict key fails and `default`')
Expand Down Expand Up @@ -2124,34 +2124,55 @@ def test_traverse_obj(self):
msg='if branched but not successful return `[]`, not `default`')
self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', ...)), [],
msg='if branched but object is empty return `[]`, not `default`')
self.assertEqual(traverse_obj(None, ...), [],
msg='if branched but object is `None` return `[]`, not `default`')
self.assertEqual(traverse_obj({0: None}, (0, ...)), [],
msg='if branched but state is `None` return `[]`, not `default`')

branching_paths = [
('fail', ...),
(..., 'fail'),
100 * ('fail',) + (...,),
(...,) + 100 * ('fail',),
]
for branching_path in branching_paths:
self.assertEqual(traverse_obj({}, branching_path), [],
msg='if branched but state is `None`, return `[]` (not `default`)')
self.assertEqual(traverse_obj({}, 'fail', branching_path), [],
msg='if branching in last alternative and previous did not match, return `[]` (not `default`)')
self.assertEqual(traverse_obj({0: 'x'}, 0, branching_path), 'x',
msg='if branching in last alternative and previous did match, return single value')
self.assertEqual(traverse_obj({0: 'x'}, branching_path, 0), 'x',
msg='if branching in first alternative and non-branching path does match, return single value')
self.assertEqual(traverse_obj({}, branching_path, 'fail'), None,
msg='if branching in first alternative and non-branching path does not match, return `default`')

# Testing expected_type behavior
_EXPECTED_TYPE_DATA = {'str': 'str', 'int': 0}
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str), 'str',
msg='accept matching `expected_type` type')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int), None,
msg='reject non matching `expected_type` type')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)), '0',
msg='transform type using type function')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str',
expected_type=lambda _: 1 / 0), None,
msg='wrap expected_type fuction in try_call')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, ..., expected_type=str), ['str'],
msg='eliminate items that expected_type fails on')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}, expected_type=int), {0: 100},
msg='type as expected_type should filter dict values')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none), {0: '100', 1: '1.2'},
msg='function as expected_type should transform dict values')
self.assertEqual(traverse_obj(_TEST_DATA, ({0: 1.2}, 0, {int_or_none}), expected_type=int), 1,
msg='expected_type should not filter non final dict values')
self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int), {0: {0: 100}},
msg='expected_type should transform deep dict values')
self.assertEqual(traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], expected_type=type(...)), [{0: ...}, {0: ...}],
msg='expected_type should transform branched dict values')
self.assertEqual(traverse_obj({1: {3: 4}}, [(1, 2), 3], expected_type=int), [4],
msg='expected_type regression for type matching in tuple branching')
self.assertEqual(traverse_obj(_TEST_DATA, ['data', ...], expected_type=int), [],
msg='expected_type regression for type matching in dict result')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str),
'str', msg='accept matching `expected_type` type')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int),
None, msg='reject non matching `expected_type` type')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)),
'0', msg='transform type using type function')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=lambda _: 1 / 0),
None, msg='wrap expected_type fuction in try_call')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, ..., expected_type=str),
['str'], msg='eliminate items that expected_type fails on')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}, expected_type=int),
{0: 100}, msg='type as expected_type should filter dict values')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none),
{0: '100', 1: '1.2'}, msg='function as expected_type should transform dict values')
self.assertEqual(traverse_obj(_TEST_DATA, ({0: 1.2}, 0, {int_or_none}), expected_type=int),
1, msg='expected_type should not filter non final dict values')
self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int),
{0: {0: 100}}, msg='expected_type should transform deep dict values')
self.assertEqual(traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], expected_type=type(...)),
[{0: ...}, {0: ...}], msg='expected_type should transform branched dict values')
self.assertEqual(traverse_obj({1: {3: 4}}, [(1, 2), 3], expected_type=int),
[4], msg='expected_type regression for type matching in tuple branching')
self.assertEqual(traverse_obj(_TEST_DATA, ['data', ...], expected_type=int),
[], msg='expected_type regression for type matching in dict result')

# Test get_all behavior
_GET_ALL_DATA = {'key': [0, 1, 2]}
Expand Down
2 changes: 1 addition & 1 deletion yt_dlp/downloader/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def download_and_append_fragments_multiple(self, *args, **kwargs):
max_workers = self.params.get('concurrent_fragment_downloads', 1)
if max_progress > 1:
self._prepare_multiline_status(max_progress)
is_live = any(traverse_obj(args, (..., 2, 'is_live'), default=[]))
is_live = any(traverse_obj(args, (..., 2, 'is_live')))

def thread_func(idx, ctx, fragments, info_dict, tpe):
ctx['max_progress'] = max_progress
Expand Down
4 changes: 2 additions & 2 deletions yt_dlp/extractor/abematv.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def _real_extract(self, url):
f'https://api.abema.io/v1/video/programs/{video_id}', video_id,
note='Checking playability',
headers=headers)
ondemand_types = traverse_obj(api_response, ('terms', ..., 'onDemandType'), default=[])
ondemand_types = traverse_obj(api_response, ('terms', ..., 'onDemandType'))
if 3 not in ondemand_types:
# cannot acquire decryption key for these streams
self.report_warning('This is a premium-only stream')
Expand Down Expand Up @@ -489,7 +489,7 @@ def _fetch_page(self, playlist_id, series_version, page):
})
yield from (
self.url_result(f'https://abema.tv/video/episode/{x}')
for x in traverse_obj(programs, ('programs', ..., 'id'), default=[]))
for x in traverse_obj(programs, ('programs', ..., 'id')))

def _entries(self, playlist_id, series_version):
return OnDemandPagedList(
Expand Down
2 changes: 1 addition & 1 deletion yt_dlp/extractor/gamejolt.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def _get_comments(self, post_num_id, post_hash_id):
post_hash_id, note='Downloading comments list page %d' % page)
if not comments_data.get('comments'):
break
for comment in traverse_obj(comments_data, (('comments', 'childComments'), ...), expected_type=dict, default=[]):
for comment in traverse_obj(comments_data, (('comments', 'childComments'), ...), expected_type=dict):
yield {
'id': comment['id'],
'text': self._parse_content_as_text(
Expand Down
8 changes: 4 additions & 4 deletions yt_dlp/extractor/iqiyi.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ def _real_extract(self, url):
'langCode': self._get_cookie('lang', 'en_us'),
'deviceId': self._get_cookie('QC005', '')
}, fatal=False)
ut_list = traverse_obj(vip_data, ('data', 'all_vip', ..., 'vipType'), expected_type=str_or_none, default=[])
ut_list = traverse_obj(vip_data, ('data', 'all_vip', ..., 'vipType'), expected_type=str_or_none)
else:
ut_list = ['0']

Expand Down Expand Up @@ -617,7 +617,7 @@ def _real_extract(self, url):
self.report_warning('This preview video is limited%s' % format_field(preview_time, None, ' to %s seconds'))

# TODO: Extract audio-only formats
for bid in set(traverse_obj(initial_format_data, ('program', 'video', ..., 'bid'), expected_type=str_or_none, default=[])):
for bid in set(traverse_obj(initial_format_data, ('program', 'video', ..., 'bid'), expected_type=str_or_none)):
dash_path = dash_paths.get(bid)
if not dash_path:
self.report_warning(f'Unknown format id: {bid}. It is currently not being extracted')
Expand All @@ -628,7 +628,7 @@ def _real_extract(self, url):
fatal=False), 'data', expected_type=dict)

video_format = traverse_obj(format_data, ('program', 'video', lambda _, v: str(v['bid']) == bid),
expected_type=dict, default=[], get_all=False) or {}
expected_type=dict, get_all=False) or {}
extracted_formats = []
if video_format.get('m3u8Url'):
extracted_formats.extend(self._extract_m3u8_formats(
Expand Down Expand Up @@ -669,7 +669,7 @@ def _real_extract(self, url):
})
formats.extend(extracted_formats)

for sub_format in traverse_obj(initial_format_data, ('program', 'stl', ...), expected_type=dict, default=[]):
for sub_format in traverse_obj(initial_format_data, ('program', 'stl', ...), expected_type=dict):
lang = self._LID_TAGS.get(str_or_none(sub_format.get('lid')), sub_format.get('_name'))
subtitles.setdefault(lang, []).extend([{
'ext': format_ext,
Expand Down
4 changes: 2 additions & 2 deletions yt_dlp/extractor/panopto.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def _real_extract(self, url):
return {
'id': video_id,
'title': delivery.get('SessionName'),
'cast': traverse_obj(delivery, ('Contributors', ..., 'DisplayName'), default=[], expected_type=lambda x: x or None),
'cast': traverse_obj(delivery, ('Contributors', ..., 'DisplayName'), expected_type=lambda x: x or None),
'timestamp': session_start_time - 11640000000 if session_start_time else None,
'duration': delivery.get('Duration'),
'thumbnail': base_url + f'/Services/FrameGrabber.svc/FrameRedirect?objectId={video_id}&mode=Delivery&random={random()}',
Expand Down Expand Up @@ -563,7 +563,7 @@ def _extract_folder_metadata(self, base_url, folder_id):
base_url, '/Services/Data.svc/GetFolderInfo', folder_id,
data={'folderID': folder_id}, fatal=False)
return {
'title': get_first(response, 'Name', default=[])
'title': get_first(response, 'Name')
}

def _real_extract(self, url):
Expand Down
2 changes: 1 addition & 1 deletion yt_dlp/extractor/patreon.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def _get_comments(self, post_id):
f'posts/{post_id}/comments', post_id, query=params, note='Downloading comments page %d' % page)

cursor = None
for comment in traverse_obj(response, (('data', ('included', lambda _, v: v['type'] == 'comment')), ...), default=[]):
for comment in traverse_obj(response, (('data', ('included', lambda _, v: v['type'] == 'comment')), ...)):
count += 1
comment_id = comment.get('id')
attributes = comment.get('attributes') or {}
Expand Down
4 changes: 2 additions & 2 deletions yt_dlp/extractor/tiktok.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def extract_addr(addr, add_meta={}):
user_url = self._UPLOADER_URL_FORMAT % (traverse_obj(author_info,
'sec_uid', 'id', 'uid', 'unique_id',
expected_type=str_or_none, get_all=False))
labels = traverse_obj(aweme_detail, ('hybrid_label', ..., 'text'), expected_type=str, default=[])
labels = traverse_obj(aweme_detail, ('hybrid_label', ..., 'text'), expected_type=str)

contained_music_track = traverse_obj(
music_info, ('matched_song', 'title'), ('matched_pgc_sound', 'title'), expected_type=str)
Expand Down Expand Up @@ -355,7 +355,7 @@ def _parse_aweme_video_web(self, aweme_detail, webpage_url):
'ext': 'mp4',
'width': width,
'height': height,
} for url in traverse_obj(play_url, (..., 'src'), expected_type=url_or_none, default=[]) if url]
} for url in traverse_obj(play_url, (..., 'src'), expected_type=url_or_none) if url]

download_url = url_or_none(video_info.get('downloadAddr')) or traverse_obj(video_info, ('download', 'url'), expected_type=url_or_none)
if download_url:
Expand Down
Loading

0 comments on commit 6839ae1

Please sign in to comment.