diff --git a/src/reader/_types.py b/src/reader/_types.py index e97b402b..dbcdc068 100644 --- a/src/reader/_types.py +++ b/src/reader/_types.py @@ -180,6 +180,7 @@ def entry_data_from_obj(obj: object) -> EntryData: """ if isinstance(obj, Mapping): obj = SimpleNamespace(**obj) + source_obj = getattr(obj, 'source', None) return EntryData( feed_url=_getattr(obj, 'feed_url', str), id=_getattr(obj, 'id', str), @@ -191,6 +192,7 @@ def entry_data_from_obj(obj: object) -> EntryData: summary=_getattr_optional(obj, 'summary', str), content=tuple(content_from_obj(o) for o in getattr(obj, 'content', ())), enclosures=tuple(enclosure_from_obj(o) for o in getattr(obj, 'enclosures', ())), + source=source_from_obj(source_obj) if source_obj else None, ) @@ -214,6 +216,19 @@ def enclosure_from_obj(obj: object) -> Enclosure: ) +def source_from_obj(obj: object) -> EntrySource: + if isinstance(obj, Mapping): + obj = SimpleNamespace(**obj) + return EntrySource( + url=_getattr_optional(obj, 'url', str), + updated=_getattr_optional_datetime(obj, 'updated'), + title=_getattr_optional(obj, 'title', str), + link=_getattr_optional(obj, 'link', str), + author=_getattr_optional(obj, 'author', str), + subtitle=_getattr_optional(obj, 'subtitle', str), + ) + + def _getattr(obj: object, name: str, type: type[_T]) -> _T: # will raise AttributeError implicitly value = getattr(obj, name) diff --git a/src/reader/core.py b/src/reader/core.py index 670eccf5..3f1eab28 100644 --- a/src/reader/core.py +++ b/src/reader/core.py @@ -1513,6 +1513,7 @@ def add_entry(self, entry: Any, /) -> None: * :attr:`~Entry.summary` * :attr:`~Entry.content` * :attr:`~Entry.enclosures` + * :attr:`~Entry.source` Naive datetimes are normalized by passing them to :meth:`~datetime.datetime.astimezone`. @@ -1532,6 +1533,9 @@ def add_entry(self, entry: Any, /) -> None: .. versionchanged:: 3.0 The ``entry`` argument is now positional-only. + .. versionchanged:: 3.16 + Allow setting :attr:`~Entry.source`. + """ # `entry` is of type Union[EntryDataLikeProtocol, EntryDataTypedDict], diff --git a/tests/test__types.py b/tests/test__types.py index e9b9ec51..5fa4ab7f 100644 --- a/tests/test__types.py +++ b/tests/test__types.py @@ -88,15 +88,13 @@ def test_tristate_filter_argument_error(input): assert 'name' in str(excinfo.value) +@pytest.mark.parametrize('feed_type', ['rss', 'atom']) @pytest.mark.parametrize('data_file', ['full', 'empty']) -def test_entry_data_from_obj(data_dir, data_file): +def test_entry_data_from_obj(data_dir, feed_type, data_file): expected = {'url_base': '', 'rel_base': ''} - exec(data_dir.joinpath(f'{data_file}.rss.py').read_bytes(), expected) + exec(data_dir.joinpath(f'{data_file}.{feed_type}.py').read_bytes(), expected) for i, entry in enumerate(expected['entries']): - # FIXME: temporary during #276 development - entry = entry._replace(source=None) - assert entry == entry_data_from_obj(entry), i entry_dict = entry._asdict() @@ -104,6 +102,8 @@ def test_entry_data_from_obj(data_dir, data_file): entry_dict['content'] = [c._asdict() for c in entry_dict['content']] if 'enclosures' in entry_dict: entry_dict['enclosures'] = [e._asdict() for e in entry_dict['enclosures']] + if entry_dict.get('source'): + entry_dict['source'] = entry_dict['source']._asdict() assert entry == entry_data_from_obj(entry_dict), i