diff --git a/coriolis/tests/scheduler/filters/__init__.py b/coriolis/tests/scheduler/filters/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/coriolis/tests/scheduler/filters/test_base.py b/coriolis/tests/scheduler/filters/test_base.py new file mode 100644 index 000000000..c44636181 --- /dev/null +++ b/coriolis/tests/scheduler/filters/test_base.py @@ -0,0 +1,52 @@ +# Copyright 2024 Cloudbase Solutions Srl +# All Rights Reserved. + +from unittest import mock + +from coriolis.scheduler.filters import base +from coriolis.tests import test_base + + +class BaseServiceFilterTestCase(test_base.CoriolisBaseTestCase): + """Test suite for the BaseServiceFilter class.""" + + @mock.patch.object(base.BaseServiceFilter, '__abstractmethods__', set()) + def setUp(self): + super(BaseServiceFilterTestCase, self).setUp() + self.service_filter = base.BaseServiceFilter() + + def test_is_service_acceptable(self): + self.service_filter.rate_service = mock.Mock() + self.service_filter.rate_service.return_value = 50 + + result = self.service_filter.is_service_acceptable( + mock.sentinel.service) + + self.service_filter.rate_service.assert_called_once_with( + mock.sentinel.service + ) + self.assertTrue(result) + + def test_is_service_acceptable_false(self): + self.service_filter.rate_service = mock.Mock() + self.service_filter.rate_service.return_value = 0 + + result = self.service_filter.is_service_acceptable( + mock.sentinel.service) + + self.service_filter.rate_service.assert_called_once_with( + mock.sentinel.service + ) + self.assertFalse(result) + + def test_filter_services(self): + self.service_filter.is_service_acceptable = mock.Mock() + self.service_filter.is_service_acceptable.side_effect = [ + True, False, True] + + result = self.service_filter.filter_services([mock.sentinel.service1, + mock.sentinel.service2, + mock.sentinel.service3]) + + self.assertEqual(result, [mock.sentinel.service1, + mock.sentinel.service3]) diff --git a/coriolis/tests/scheduler/filters/test_trivial_filters.py b/coriolis/tests/scheduler/filters/test_trivial_filters.py new file mode 100644 index 000000000..eac435cd1 --- /dev/null +++ b/coriolis/tests/scheduler/filters/test_trivial_filters.py @@ -0,0 +1,137 @@ +# Copyright 2024 Cloudbase Solutions Srl +# All Rights Reserved. + +from unittest import mock + +import ddt + +from coriolis.scheduler.filters import trivial_filters +from coriolis.tests import test_base + + +@ddt.ddt +class RegionsFilterTestCase(test_base.CoriolisBaseTestCase): + """Test suite for the RegionsFilter class.""" + + def setUp(self): + super(RegionsFilterTestCase, self).setUp() + self.regions = [mock.sentinel.region1, mock.sentinel.region2] + self.any_region = False + self.regions_filter = trivial_filters.RegionsFilter( + self.regions, self.any_region) + + def test__repr__(self): + result = repr(self.regions_filter) + + self.assertEqual(result, "" % (self.regions, + self.any_region)) + + @ddt.data( + (None, [mock.sentinel.region1, mock.sentinel.region2], 100, False), + ([mock.sentinel.region1, mock.sentinel.region2], + [mock.sentinel.region1, mock.sentinel.region2], 100, False), + ([mock.sentinel.region1, mock.sentinel.region2], [], 0, False), + ([mock.sentinel.region1, mock.sentinel.region2], + [mock.sentinel.region1], 0, False), + ([mock.sentinel.region1, mock.sentinel.region2], + [mock.sentinel.region1], 100, True) + ) + @ddt.unpack + def test_rate_service(self, regions, mapped_regions, expected, any_region): + self.regions_filter._regions = regions + self.regions_filter._any_region = any_region + mock_service = mock.Mock() + mock_service.mapped_regions = [mock.Mock(id=region) for region in + mapped_regions] + + result = self.regions_filter.rate_service(mock_service) + + self.assertEqual(result, expected) + + +class TopicFilterTestCase(test_base.CoriolisBaseTestCase): + """Test suite for the TopicFilter class.""" + + def setUp(self): + super(TopicFilterTestCase, self).setUp() + self.service = mock.Mock() + self.topic = mock.sentinel.topic + self.topic_filter = trivial_filters.TopicFilter(self.topic) + + def test__repr__(self): + result = repr(self.topic_filter) + self.assertEqual(result, "" % self.topic) + + def test_rate_service(self): + self.service.topic = self.topic + + result = self.topic_filter.rate_service(self.service) + self.assertEqual(result, 100) + + def test_rate_service_false(self): + self.service.topic = mock.sentinel.other_topic + + result = self.topic_filter.rate_service(self.service) + self.assertEqual(result, 0) + + +class EnabledFilterTestCase(test_base.CoriolisBaseTestCase): + """Test suite for the EnabledFilter class.""" + + def setUp(self): + super(EnabledFilterTestCase, self).setUp() + self.service = mock.Mock() + self.enabled = True + self.enabled_filter = trivial_filters.EnabledFilter() + + def test__repr__(self): + result = repr(self.enabled_filter) + self.assertEqual(result, "" % self.enabled) + + def test_rate_service(self): + self.service.enabled = True + + result = self.enabled_filter.rate_service(self.service) + self.assertEqual(result, 100) + + def test_rate_service_false(self): + self.service.enabled = False + + result = self.enabled_filter.rate_service(self.service) + self.assertEqual(result, 0) + + +@ddt.ddt +class ProviderTypesFilterTestCase(test_base.CoriolisBaseTestCase): + """Test suite for the ProviderTypesFilter class.""" + + def setUp(self): + super(ProviderTypesFilterTestCase, self).setUp() + self.service = mock.Mock() + self.provider_types = [mock.sentinel.provider_type1, + mock.sentinel.provider_type2] + self.provider_types_filter = trivial_filters.ProviderTypesFilter( + self.provider_types) + + def test__repr__(self): + result = repr(self.provider_types_filter) + self.assertEqual(result, + "" % + self.provider_types) + + @ddt.data( + ({'platform1': ['type1']}, {'platform1': {'types': ['type1']}}, 100), + ({'platform1': ['type1']}, {'platform2': {'types': ['type1']}}, 0), + ({'platform1': ['type1', 'type2']}, + {'platform1': {'types': ['type1']}}, 0) + ) + @ddt.unpack + def test_rate_service(self, provider_requirements, providers, expected): + self.provider_types_filter.\ + _provider_requirements = provider_requirements + self.service.providers = providers + + result = self.provider_types_filter.rate_service(self.service) + + self.assertEqual(result, expected) diff --git a/coriolis/tests/scheduler/rpc/test_client.py b/coriolis/tests/scheduler/rpc/test_client.py new file mode 100644 index 000000000..ddfb5c415 --- /dev/null +++ b/coriolis/tests/scheduler/rpc/test_client.py @@ -0,0 +1,242 @@ +# Copyright 2024 Cloudbase Solutions Srl +# All Rights Reserved. + +import logging +from unittest import mock + +import oslo_messaging + +from coriolis import constants +from coriolis import exception +from coriolis.scheduler.rpc import client +from coriolis.tasks import factory as tasks_factory +from coriolis.tests import test_base + + +class SchedulerClientTestCase(test_base.CoriolisBaseTestCase): + """Test suite for the Coriolis Scheduler Worker RPC client.""" + + def setUp(self): + super(SchedulerClientTestCase, self).setUp() + self.client = client.SchedulerClient() + self.task = {'id': 'task_id', 'task_type': 'task_type'} + self.origin_endpoint = { + 'id': 'origin_id', + 'mapped_regions': [{'id': 'region1'}, {'id': 'region2'}], + 'type': 'origin_type' + } + self.destination_endpoint = { + 'id': 'destination_id', + 'mapped_regions': [{'id': 'region3'}, {'id': 'region4'}], + 'type': 'destination_type' + } + + @mock.patch('coriolis.scheduler.rpc.client.CONF') + @mock.patch.object(oslo_messaging, 'Target') + def test__init__(self, mock_target, mock_conf): + expected_timeout = 120 + mock_conf.scheduler.scheduler_rpc_timeout = expected_timeout + + result = client.SchedulerClient() + mock_target.assert_called_once_with( + topic='coriolis_scheduler', version=client.VERSION) + + self.assertEqual(result._target, mock_target.return_value) + self.assertEqual(result._timeout, expected_timeout) + + def test__init__without_timeout(self): + result = client.SchedulerClient() + self.assertEqual(result._timeout, 60) + + def test__init__with_timeout(self): + result = client.SchedulerClient(timeout=120) + self.assertEqual(result._timeout, 120) + + @mock.patch.object(client.SchedulerClient, '_call') + def test_get_diagnostics(self, mock_call): + ctxt = mock.sentinel.ctxt + result = self.client.get_diagnostics(ctxt) + + mock_call.assert_called_once_with(ctxt, 'get_diagnostics') + self.assertEqual(result, mock_call.return_value) + + @mock.patch.object(client.SchedulerClient, '_call') + def test_get_workers_for_specs(self, mock_call): + ctxt = mock.sentinel.ctxt + provider_requirements = mock.sentinel.provider_requirements + region_sets = mock.sentinel.region_sets + enabled = mock.sentinel.enabled + + result = self.client.get_workers_for_specs( + ctxt, provider_requirements=provider_requirements, + region_sets=region_sets, enabled=enabled) + + mock_call.assert_called_once_with( + ctxt, 'get_workers_for_specs', region_sets=region_sets, + enabled=enabled, provider_requirements=provider_requirements) + self.assertEqual(result, mock_call.return_value) + + @mock.patch('random.choice') + @mock.patch.object(client.SchedulerClient, 'get_workers_for_specs') + def test_get_any_worker_service(self, mock_get_workers_for_specs, + mock_random_choice): + ctxt = mock.sentinel.ctxt + raise_if_none = mock.sentinel.raise_if_none + mock_service = mock.MagicMock() + mock_get_workers_for_specs.return_value = mock_service + mock_random_choice.return_value = mock_service + + result = self.client.get_any_worker_service( + ctxt, raise_if_none=raise_if_none, random_choice=True) + + mock_get_workers_for_specs.assert_called_once_with(ctxt) + mock_random_choice.assert_called_once_with(mock_service) + self.assertEqual(result, mock_service) + + @mock.patch.object(client.SchedulerClient, 'get_workers_for_specs') + def test_get_any_worker_service_no_services_no_raise(self, + mock_get_workers): + mock_get_workers.return_value = [] + result = self.client.get_any_worker_service( + mock.sentinel.ctxt, raise_if_none=False) + self.assertIsNone(result) + + @mock.patch.object(client.SchedulerClient, 'get_workers_for_specs') + def test_get_any_worker_service_no_services(self, mock_get_workers): + mock_get_workers.return_value = [] + self.assertRaises( + exception.NoWorkerServiceError, + self.client.get_any_worker_service, mock.sentinel.ctxt) + + @mock.patch.object(client.SchedulerClient, 'get_workers_for_specs') + def test_get_any_worker_service_random_choice(self, mock_get_workers): + service_mock1 = {'id': 'test_id1'} + service_mock2 = {'id': 'test_id2'} + mock_get_workers.return_value = [service_mock1, service_mock2] + + result = self.client.get_any_worker_service( + mock.sentinel.ctxt, random_choice=True) + + self.assertIsInstance(result, dict) + + @mock.patch.object(client.SchedulerClient, '_call') + def test_get_worker_service_for_specs(self, mock_call): + ctxt = mock.sentinel.ctxt + provider_requirements = mock.sentinel.provider_requirements + region_sets = mock.sentinel.region_sets + enabled = mock.sentinel.enabled + raise_on_no_matches = mock.sentinel.raise_on_no_matches + + self.client.get_worker_service_for_specs( + ctxt, provider_requirements=provider_requirements, + region_sets=region_sets, enabled=enabled, + raise_on_no_matches=raise_on_no_matches) + + mock_call.assert_called_once_with( + ctxt, 'get_workers_for_specs', region_sets=region_sets, + enabled=enabled, provider_requirements=provider_requirements) + + @mock.patch.object(client.SchedulerClient, 'get_workers_for_specs') + def test_get_worker_service_for_specs_no_services_no_raise( + self, mock_get_workers): + mock_get_workers.return_value = [] + result = self.client.get_worker_service_for_specs( + mock.sentinel.ctxt, raise_on_no_matches=False) + self.assertIsNone(result) + + @mock.patch.object(client.SchedulerClient, 'get_workers_for_specs') + def test_get_worker_service_for_specs_no_services(self, mock_get_workers): + mock_get_workers.return_value = [] + self.assertRaises( + exception.NoSuitableWorkerServiceError, + self.client.get_worker_service_for_specs, mock.sentinel.ctxt) + + @mock.patch.object(client.SchedulerClient, 'get_workers_for_specs') + @mock.patch('random.choice') + def test_get_worker_service_for_specs_random_choice( + self, mock_random_choice, mock_get_workers): + service_mock1 = {'id': 'test_id1'} + service_mock2 = {'id': 'test_id2'} + mock_get_workers.return_value = [service_mock1, service_mock2] + mock_random_choice.return_value = service_mock1 + + result = self.client.get_worker_service_for_specs( + mock.sentinel.ctxt, random_choice=True) + + mock_random_choice.assert_called_once_with([ + service_mock1, service_mock2]) + mock_get_workers.assert_called_once_with( + mock.sentinel.ctxt, provider_requirements=None, region_sets=None, + enabled=True) + self.assertEqual(result, service_mock1) + + @mock.patch('random.choice') + @mock.patch.object(client.SchedulerClient, 'get_workers_for_specs') + def test_get_worker_service_for_specs_no_random_choice( + self, get_workers_for_specs, mock_random_choice): + mock_service = mock.MagicMock() + get_workers_for_specs.return_value = [mock_service] + + result = self.client.get_worker_service_for_specs( + mock.sentinel.ctxt, random_choice=False) + + mock_random_choice.assert_not_called() + get_workers_for_specs.assert_called_once_with( + mock.sentinel.ctxt, provider_requirements=None, region_sets=None, + enabled=True) + self.assertEqual(result, mock_service) + + @mock.patch.object(client.SchedulerClient, 'get_worker_service_for_specs') + @mock.patch.object(tasks_factory, 'get_task_runner_class') + def test_get_worker_service_for_task_different_platforms( + self, mock_get_task_runner_class, get_worker_service_for_specs): + for platform in [constants.TASK_PLATFORM_SOURCE, + constants.TASK_PLATFORM_DESTINATION, + constants.TASK_PLATFORM_BILATERAL]: + mock_get_task_runner_class.return_value.get_required_platform.\ + return_value = platform + mock_get_task_runner_class.return_value.\ + get_required_provider_types.return_value = { + constants.PROVIDER_PLATFORM_SOURCE: 'provider_type'} + get_worker_service_for_specs.return_value = {'id': 'test_id'} + + result = self.client.get_worker_service_for_task( + mock.sentinel.ctxt, self.task, self.origin_endpoint, + self.destination_endpoint, retry_period=0) + + self.assertEqual(result, {'id': 'test_id'}) + self.assertIsInstance(result, dict) + + @mock.patch.object(client.SchedulerClient, 'get_worker_service_for_specs') + @mock.patch.object(tasks_factory, 'get_task_runner_class') + def test_get_worker_service_for_task_retry( + self, mock_get_task_runner_class, mock_get_worker_service): + mock_get_task_runner_class.return_value.get_required_platform.\ + return_value = constants.TASK_PLATFORM_SOURCE + mock_get_task_runner_class.return_value.get_required_provider_types.\ + return_value = {constants.PROVIDER_PLATFORM_DESTINATION: + 'provider_type'} + mock_get_worker_service.side_effect = [Exception(), {'id': 'test_id'}] + + with self.assertLogs('coriolis.scheduler.rpc.client', + level=logging.WARN): + self.client.get_worker_service_for_task( + mock.sentinel.ctxt, self.task, self.origin_endpoint, + self.destination_endpoint, retry_period=0) + + @mock.patch.object(client.SchedulerClient, 'get_worker_service_for_specs') + @mock.patch.object(tasks_factory, 'get_task_runner_class') + def test_get_worker_service_for_task_no_suitable_worker( + self, mock_get_task_runner_class, mock_get_worker_service): + mock_get_task_runner_class.return_value.get_required_platform.\ + return_value = constants.TASK_PLATFORM_SOURCE + mock_get_task_runner_class.return_value.get_required_provider_types.\ + return_value = {constants.PROVIDER_PLATFORM_SOURCE: 'type'} + mock_get_worker_service.side_effect = [ + exception.NoSuitableWorkerServiceError()] + + self.assertRaises( + exception.NoSuitableWorkerServiceError, + self.client.get_worker_service_for_task, mock.sentinel.ctxt, + self.task, self.origin_endpoint, self.destination_endpoint, + retry_period=0, retry_count=0) diff --git a/coriolis/tests/scheduler/rpc/test_server.py b/coriolis/tests/scheduler/rpc/test_server.py index c02871384..bdee12eb4 100644 --- a/coriolis/tests/scheduler/rpc/test_server.py +++ b/coriolis/tests/scheduler/rpc/test_server.py @@ -1,15 +1,19 @@ -# Copyright 2023 Cloudbase Solutions Srl +# Copyright 2024 Cloudbase Solutions Srl # All Rights Reserved. +import logging from unittest import mock import ddt +from coriolis import constants +from coriolis.db import api as db_api from coriolis import exception from coriolis.scheduler.filters import trivial_filters from coriolis.scheduler.rpc import server from coriolis.tests import test_base from coriolis.tests import testutils +from coriolis import utils @ddt.ddt @@ -20,6 +24,122 @@ def setUp(self): super(SchedulerServerEndpointTestCase, self).setUp() self.server = server.SchedulerServerEndpoint() + @mock.patch.object(utils, "get_diagnostics_info") + def test_get_diagnostics(self, mock_get_diagnostics_info): + result = self.server.get_diagnostics(mock.sentinel.context) + + mock_get_diagnostics_info.assert_called_once_with() + self.assertEqual(result, mock_get_diagnostics_info.return_value) + + @mock.patch.object(trivial_filters, 'TopicFilter', autospec=True) + @mock.patch.object(db_api, 'get_services') + def test_get_all_worker_services(self, mock_get_services, + mock_topic_filter_cls): + mock_get_services.return_value = mock.sentinel.services + + mock_topic_filter_cls.return_value.filter_services.return_value = \ + mock.sentinel.filtered_services + + result = self.server._get_all_worker_services(mock.sentinel.context) + + mock_get_services.assert_called_once_with(mock.sentinel.context) + mock_topic_filter_cls.assert_called_once_with( + constants.WORKER_MAIN_MESSAGING_TOPIC) + mock_topic_filter_cls.return_value.filter_services.\ + assert_called_once_with(mock.sentinel.services) + + self.assertEqual(result, mock.sentinel.filtered_services) + + @mock.patch.object(db_api, 'get_services') + def test_get_all_worker_services_no_services(self, mock_get_services): + mock_get_services.return_value = [] + + self.assertRaises(exception.NoWorkerServiceError, + self.server._get_all_worker_services, + mock.sentinel.context) + + mock_get_services.assert_called_once_with(mock.sentinel.context) + + def test_get_weighted_filtered_services_no_filters(self): + services = [mock.Mock(id=1), mock.Mock(id=2)] + + with self.assertLogs('coriolis.scheduler.rpc.server', + level=logging.WARN): + result = self.server._get_weighted_filtered_services(services, + None) + expected_result = [(services[0], 100), (services[1], 100)] + self.assertEqual(result, expected_result) + + def test_get_weighted_filtered_services_with_filters_reject(self): + services = [mock.Mock(id=1), mock.Mock(id=2)] + filters = [mock.Mock(), mock.Mock()] + filters[0].rate_service.return_value = 50 + filters[1].rate_service.return_value = 0 + + self.assertRaises(exception.NoSuitableWorkerServiceError, + self.server._get_weighted_filtered_services, + services, filters) + + def test_get_weighted_filtered_services_with_filters_accept(self): + services = [mock.Mock(id=1), mock.Mock(id=2)] + filters = [mock.Mock(), mock.Mock()] + filters[0].rate_service.return_value = 50 + filters[1].rate_service.return_value = 100 + + result = self.server._get_weighted_filtered_services(services, + filters) + expected_result = [(services[0], 150), (services[1], 150)] + self.assertEqual(result, expected_result) + + @mock.patch.object(db_api, 'get_regions') + def test__filter_regions_check_all_exist_false(self, mock_get_regions): + mock_get_regions.return_value = [ + mock.Mock(id='region1', enabled=True), + mock.Mock(id='region2', enabled=True), + ] + region_ids = ['region1', 'region2'] + + result = self.server._filter_regions(None, region_ids, + check_all_exist=False) + + self.assertEqual(result, mock_get_regions.return_value) + + @mock.patch.object(db_api, 'get_regions') + def test__filter_regions_all_disabled(self, mock_get_regions): + mock_get_regions.return_value = [ + mock.Mock(id='region1', enabled=False), + mock.Mock(id='region2', enabled=False), + ] + region_ids = ['region1', 'region2'] + + result = self.server._filter_regions(None, region_ids, enabled=False) + + self.assertEqual(result, mock_get_regions.return_value) + + @mock.patch.object(db_api, 'get_regions') + def test__filter_regions_some_enabled_some_disabled(self, + mock_get_regions): + mock_get_regions.return_value = [ + mock.Mock(id='region1', enabled=True), + mock.Mock(id='region2', enabled=False), + ] + region_ids = ['region1', 'region2'] + + result = self.server._filter_regions(None, region_ids) + + self.assertEqual(result, [mock_get_regions.return_value[0]]) + + @mock.patch.object(db_api, 'get_regions') + def test__filter_regions_some_missing(self, mock_get_regions): + mock_get_regions.return_value = [ + mock.Mock(id='region1', enabled=True), + mock.Mock(id='region2', enabled=True), + ] + region_ids = ['region1', 'region2', 'region3'] + + self.assertRaises(exception.RegionNotFound, + self.server._filter_regions, None, region_ids) + @mock.patch.object(trivial_filters, 'ProviderTypesFilter', autospec=True) @mock.patch.object(trivial_filters, 'RegionsFilter', autospec=True) @mock.patch.object(trivial_filters, 'EnabledFilter', autospec=True) @@ -52,7 +172,7 @@ def test_get_workers_for_specs( provider_requirements = config.get("provider_requirements", None) # Convert the config dict to an object, skipping the providers - # providers is the only field used as dict in the code + # as it's the only field used as dict in the code config_obj = testutils.DictToObject(config, skip_attrs=["providers"]) mock_get_all_worker_services.return_value = ( config_obj.services_db or [] diff --git a/coriolis/tests/scheduler/test_scheduler_utils.py b/coriolis/tests/scheduler/test_scheduler_utils.py new file mode 100644 index 000000000..1ba07f9ff --- /dev/null +++ b/coriolis/tests/scheduler/test_scheduler_utils.py @@ -0,0 +1,125 @@ +# Copyright 2024 Cloudbase Solutions Srl +# All Rights Reserved. + +from unittest import mock + +from coriolis import constants +from coriolis import exception +from coriolis.scheduler import scheduler_utils +from coriolis.tests import test_base + + +class CoriolisTestException(Exception): + pass + + +class SchedulerUtilsTestCase(test_base.CoriolisBaseTestCase): + """Test suite for the Coriolis scheduler utils package.""" + + def setUp(self): + super(SchedulerUtilsTestCase, self).setUp() + self.scheduler_client = mock.MagicMock() + self.rpc_client_class = mock.MagicMock() + self.service = mock.MagicMock() + self.ctxt = mock.MagicMock() + + def test_get_rpc_client_for_service(self): + with mock.patch.dict( + scheduler_utils.RPC_TOPIC_TO_CLIENT_CLASS_MAP, + {constants.WORKER_MAIN_MESSAGING_TOPIC: self.rpc_client_class}, + clear=True + ): + self.service.topic = constants.WORKER_MAIN_MESSAGING_TOPIC + self.service.host = 'test_host' + + result = scheduler_utils.get_rpc_client_for_service(self.service) + + self.rpc_client_class.assert_called_once_with( + topic='coriolis_worker.test_host') + + self.assertEqual(result, self.rpc_client_class.return_value) + + def test_get_rpc_client_for_service_different_topic(self): + with mock.patch.dict( + scheduler_utils.RPC_TOPIC_TO_CLIENT_CLASS_MAP, + {mock.sentinel.topic: self.rpc_client_class}, + clear=True + ): + self.service.topic = mock.sentinel.topic + self.service.host = 'host' + + result = scheduler_utils.get_rpc_client_for_service(self.service) + + self.rpc_client_class.assert_called_once_with( + topic=mock.sentinel.topic) + + self.assertEqual(result, self.rpc_client_class.return_value) + + def test_get_rpc_client_for_service_with_exception(self): + self.service.topic = 'non-existent-topic' + self.service.host = 'host' + + self.assertRaises(exception.NotFound, + scheduler_utils.get_rpc_client_for_service, + self.service) + + def test_get_any_worker_service_no_services(self): + self.scheduler_client.get_workers_for_specs.return_value = [] + + self.assertRaises(exception.NoWorkerServiceError, + scheduler_utils.get_any_worker_service, + self.scheduler_client, self.ctxt) + + @mock.patch('coriolis.scheduler.scheduler_utils.db_api.get_service') + @mock.patch('random.choice') + def test_get_any_worker_service_random_choice(self, mock_random_choice, + get_service_mock): + service_mock1 = {'id': 'test_id1'} + service_mock2 = {'id': 'test_id2'} + + self.scheduler_client.get_workers_for_specs.return_value = [ + service_mock1, service_mock2] + + get_service_mock.return_value = [service_mock1, service_mock2] + mock_random_choice.return_value = service_mock1 + + result = scheduler_utils.get_any_worker_service( + self.scheduler_client, self.ctxt, random_choice=True) + + mock_random_choice.assert_called_once_with([ + service_mock1, service_mock2]) + get_service_mock.assert_called_once_with( + self.ctxt, service_mock1['id']) + self.assertEqual(result, get_service_mock.return_value) + + @mock.patch('coriolis.scheduler.scheduler_utils.db_api.get_service') + def test_get_any_worker_service_raw_dict(self, get_service_mock): + service_mock = {'id': 'test_id'} + + self.scheduler_client.get_workers_for_specs.return_value = [ + service_mock] + + result = scheduler_utils.get_any_worker_service( + self.scheduler_client, self.ctxt, raw_dict=True) + + get_service_mock.assert_not_called() + self.assertEqual(result, service_mock) + + def test_get_worker_rpc_for_host(self): + with mock.patch.dict( + scheduler_utils.RPC_TOPIC_TO_CLIENT_CLASS_MAP, + {constants.WORKER_MAIN_MESSAGING_TOPIC: self.rpc_client_class}, + clear=True + ): + host = 'test_host' + client_args = ('arg1', 'arg2') + client_kwargs = {'key1': 'value1', 'key2': 'value2'} + + result = scheduler_utils.get_worker_rpc_for_host( + host, *client_args, **client_kwargs) + + self.rpc_client_class.assert_called_once_with( + *client_args, topic='coriolis_worker.test_host', + **client_kwargs) + + self.assertEqual(result, self.rpc_client_class.return_value) diff --git a/coriolis/tests/services/__init__.py b/coriolis/tests/services/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/coriolis/tests/services/test_api.py b/coriolis/tests/services/test_api.py new file mode 100644 index 000000000..fa4d9e6ae --- /dev/null +++ b/coriolis/tests/services/test_api.py @@ -0,0 +1,43 @@ +# Copyright 2023 Cloudbase Solutions Srl +# All Rights Reserved. + + +from unittest import mock + +from coriolis.services import api as api_service +from coriolis.tests import test_base + + +class APITestCase(test_base.CoriolisBaseTestCase): + """Test suite for the Coriolis API service.""" + + def setUp(self): + super(APITestCase, self).setUp() + self.api = api_service.API() + self.rpc_client_mock = mock.Mock() + self.api._rpc_client = self.rpc_client_mock + + def test_create(self): + self.api.create('ctxt', 'host', 'binary', 'topic', 'mapped_regions', + True) + self.rpc_client_mock.register_service.assert_called_once_with( + 'ctxt', 'host', 'binary', 'topic', True, 'mapped_regions') + + def test_update(self): + self.api.update('ctxt', 'service_id', 'updated_values') + self.rpc_client_mock.update_service.assert_called_once_with( + 'ctxt', 'service_id', 'updated_values') + + def test_delete(self): + self.api.delete('ctxt', 'region_id') + self.rpc_client_mock.delete_service.assert_called_once_with( + 'ctxt', 'region_id') + + def test_get_services(self): + self.api.get_services('ctxt') + self.rpc_client_mock.get_services.assert_called_once_with('ctxt') + + def test_get_service(self): + self.api.get_service('ctxt', 'service_id') + self.rpc_client_mock.get_service.assert_called_once_with( + 'ctxt', 'service_id')