diff --git a/tests/test_wire.py b/tests/test_wire.py index 9438a374..b547f7ad 100644 --- a/tests/test_wire.py +++ b/tests/test_wire.py @@ -11,6 +11,7 @@ from tests.test_users import ADMIN_USER_ID from tests.test_download import setup_embeds from superdesk import get_resource_service +from unittest.mock import patch def test_item_detail(client): @@ -63,6 +64,10 @@ def test_share_items(client, app): def get_bookmarks_count(client, user): + with client.session_transaction() as session: + session['user'] = user + session['user_type'] = 'public' + resp = client.get('/api/wire_search?bookmarks=%s' % str(user)) assert resp.status_code == 200 data = json.loads(resp.get_data()) @@ -70,7 +75,6 @@ def get_bookmarks_count(client, user): def test_bookmarks(client, app): - pass; user_id = get_admin_user_id(app) assert user_id @@ -587,8 +591,15 @@ def test_search_by_products_and_filtered_by_embargoe(client, app): 'embargoed': (datetime.now() + timedelta(days=10)).replace(tzinfo=pytz.UTC), 'products': [{'code': '10'}] }]) - items = get_resource_service('wire_search').get_product_items(10, 20) - assert 0 == len(items) + + # with app.test_request_context(): + mock_user = {'_id': 'test_user_id', 'user_type': 'administrator'} + + # Use a context manager to patch get_user + with patch('newsroom.wire.search.get_user') as mock_get_user: + mock_get_user.return_value = mock_user + items = get_resource_service('wire_search').get_product_items(10, 20) + assert 0 == len(items) # ex-embargoed item is fetched app.data.insert('items', [{ @@ -597,9 +608,12 @@ def test_search_by_products_and_filtered_by_embargoe(client, app): 'embargoed': (datetime.now() - timedelta(days=10)).replace(tzinfo=pytz.UTC), 'products': [{'code': '10'}] }]) - items = get_resource_service('wire_search').get_product_items(10, 20) - assert 1 == len(items) - assert items[0]['headline'] == 'china story' + + with patch('newsroom.wire.search.get_user') as mock_get_user: + mock_get_user.return_value = mock_user + items = get_resource_service('wire_search').get_product_items(10, 20) + assert 1 == len(items) + assert items[0]['headline'] == 'china story' def test_wire_delete(client, app):