diff --git a/coriolis/tests/db/test_api.py b/coriolis/tests/db/test_api.py index 9c3980ec4..784ee86c2 100644 --- a/coriolis/tests/db/test_api.py +++ b/coriolis/tests/db/test_api.py @@ -1,29 +1,608 @@ # Copyright 2017 Cloudbase Solutions Srl # All Rights Reserved. - from unittest import mock +import uuid + +import ddt +import sqlalchemy.orm +from coriolis import constants from coriolis.db import api +from coriolis.db.sqlalchemy import api as sqlalchemy_api +from coriolis.db.sqlalchemy import models from coriolis import exception from coriolis.tests import test_base from coriolis.tests import testutils +CONTEXT_MOCK = mock.MagicMock() +DEFAULT_INSTANCE = "instance1" +DEFAULT_USER_ID = "1" +DEFAULT_PROJECT_ID = "1" +DEFAULT_TASK_INFO = {DEFAULT_INSTANCE: {"volumes_info": []}} + + +def get_valid_endpoint( + endpoint_id=None, user_id=DEFAULT_USER_ID, + project_id=DEFAULT_PROJECT_ID, connection_info=None, + endpoint_type="openstack", name="test_name", + description="Endpoint Description"): + if endpoint_id is None: + endpoint_id = str(uuid.uuid4()) + if connection_info is None: + connection_info = {"conn_info": {"secret": "info"}} + + endpoint = models.Endpoint() + endpoint.id = endpoint_id + endpoint.user_id = user_id + endpoint.project_id = project_id + endpoint.connection_info = connection_info + endpoint.type = endpoint_type + endpoint.name = name + endpoint.description = description + + return endpoint + + +class BaseDBAPITestCase(test_base.CoriolisBaseTestCase): + + valid_data = { + "user_scope": {}, + "outer_scope": {} + } + + @classmethod + def setup_scoped_data(cls, region_id, project_id="1"): + data = dict() + valid_endpoint_source = get_valid_endpoint( + endpoint_type='vmware', project_id=project_id) + cls.session.add(valid_endpoint_source) + data['source_endpoint'] = valid_endpoint_source + valid_endpoint_destination = get_valid_endpoint( + endpoint_type='openstack', project_id=project_id) + cls.session.add(valid_endpoint_destination) + data['destination_endpoint'] = valid_endpoint_destination + + valid_endpoint_region_mapping = models.EndpointRegionMapping() + valid_endpoint_region_mapping.id = str(uuid.uuid4()) + valid_endpoint_region_mapping.endpoint_id = valid_endpoint_source.id + valid_endpoint_region_mapping.region_id = region_id + cls.session.add(valid_endpoint_region_mapping) + data['endpoint_mapping'] = valid_endpoint_region_mapping + + valid_transfer = models.Transfer() + valid_transfer.id = str(uuid.uuid4()) + valid_transfer.user_id = project_id + valid_transfer.project_id = project_id + valid_transfer.base_id = valid_transfer.id + valid_transfer.scenario = constants.TRANSFER_ACTION_TYPE_REPLICA + valid_transfer.last_execution_status = ( + constants.EXECUTION_STATUS_RUNNING) + valid_transfer.executions = [] + valid_transfer.instances = [DEFAULT_INSTANCE] + valid_transfer.info = DEFAULT_TASK_INFO + valid_transfer.origin_endpoint_id = valid_endpoint_source.id + valid_transfer.destination_endpoint_id = valid_endpoint_destination.id + cls.session.add(valid_transfer) + data['transfer'] = valid_transfer + + valid_tasks_execution = models.TasksExecution() + valid_tasks_execution.id = str(uuid.uuid4()) + valid_tasks_execution.action = valid_transfer + valid_tasks_execution.status = constants.EXECUTION_STATUS_RUNNING + valid_tasks_execution.type = constants.EXECUTION_TYPE_REPLICA_EXECUTION + valid_tasks_execution.number = 1 + data['tasks_execution'] = valid_tasks_execution + + valid_task = models.Task() + valid_task.id = str(uuid.uuid4()) + valid_task.execution = valid_tasks_execution + valid_task.instance = DEFAULT_INSTANCE + valid_task.status = constants.TASK_STATUS_RUNNING + valid_task.task_type = ( + constants.TASK_TYPE_VALIDATE_REPLICA_SOURCE_INPUTS) + valid_task.index = 1 + valid_task.on_error = False + data['task'] = valid_task + + valid_progress_update = models.TaskProgressUpdate() + valid_progress_update.id = str(uuid.uuid4()) + valid_progress_update.task = valid_task + valid_progress_update.index = 1 + valid_progress_update.current_step = 0 + + valid_task_event = models.TaskEvent() + valid_task_event.id = str(uuid.uuid4()) + valid_task_event.task = valid_task + valid_task_event.level = constants.TASK_EVENT_INFO + valid_task_event.index = 1 + valid_task_event.message = "event message test" + cls.session.add(valid_tasks_execution) + data['tasks_execution'] = valid_tasks_execution + + return data + + @classmethod + def setup_database_data(cls): + cls.valid_region = models.Region() + cls.valid_region.id = str(uuid.uuid4()) + cls.valid_region.name = "region1" + cls.valid_region.enabled = True + cls.session.add(cls.valid_region) + + cls.valid_data['user_scope'] = cls.setup_scoped_data( + cls.valid_region.id) + cls.valid_data['outer_scope'] = cls.setup_scoped_data( + cls.valid_region.id, project_id="2") + cls.session.commit() + + @classmethod + def setUpClass(cls): + super(BaseDBAPITestCase, cls).setUpClass() + with mock.patch.object(sqlalchemy_api, 'CONF') as mock_conf: + mock_conf.database.connection = "sqlite://" + engine = api.get_engine() + models.BASE.metadata.create_all(engine) + cls.session = api.get_session() + cls.setup_database_data() + + def setUp(self): + super(BaseDBAPITestCase, self).setUp() + self.context = CONTEXT_MOCK + self.context.session = self.session + self.context.show_deleted = False + self.context.user = DEFAULT_USER_ID + self.context.project_id = DEFAULT_PROJECT_ID + self.context.is_admin = False + + def tearDown(self): + self.context.reset_mock() + super(BaseDBAPITestCase, self).tearDown() + + @classmethod + def tearDownClass(cls): + cls.session.rollback() + cls.session.close() + super(BaseDBAPITestCase, cls).tearDownClass() + + +@ddt.ddt +class DBAPITestCase(BaseDBAPITestCase): + """Test suite for the common Coriolis DB API.""" + + def test_get_engine(self): + self.assertEqual(api.get_engine(), api.IMPL.get_engine()) + + def test_get_session(self): + self.assertIsInstance(api.get_session(), sqlalchemy.orm.Session) + + @mock.patch.object(api, 'IMPL') + def test_db_sync(self, mock_impl): + self.assertEqual( + api.db_sync(mock.sentinel.engine, version=mock.sentinel.version), + mock_impl.db_sync.return_value) + mock_impl.db_sync.assert_called_once_with( + mock.sentinel.engine, version=mock.sentinel.version) + + @mock.patch.object(api, 'IMPL') + def test_db_version(self, mock_impl): + self.assertEqual( + api.db_version(mock.sentinel.engine), + mock_impl.db_version.return_value) + mock_impl.db_version.assert_called_once_with(mock.sentinel.engine) + + def test__session(self): + self.assertEqual(api._session(self.context), self.context.session) + + @mock.patch.object(api, 'get_session') + def test__session_no_context(self, mock_get_session): + self.assertEqual( + api._session(None), + mock_get_session.return_value) + + @mock.patch.object(api, 'get_session') + def test__session_sessionless_context(self, mock_get_session): + context = mock.Mock(session=None) + self.assertEqual( + api._session(context), + mock_get_session.return_value) + + @ddt.data( + {"kwargs": None, "expected_result": False}, + {"kwargs": {}, "expected_result": False}, + {"kwargs": {"user_id": None}, "expected_result": False}, + {"kwargs": {"user_id": "1", "project_id": None}, + "expected_result": False}, + {"kwargs": {"user_id": "1", "project_id": "1", "is_admin": True}, + "expected_result": False}, + {"kwargs": {"user_id": "1", "project_id": "1", "is_admin": False}, + "expected_result": True}, + ) + def test_is_user_context(self, data): + kwargs = data.get('kwargs') + if kwargs is None: + context = None + else: + context = mock.Mock(**data.get('kwargs', {})) + self.assertEqual( + api.is_user_context(context), data.get('expected_result')) + + @mock.patch.object(api, '_session') + def test__model_query(self, mock_session): + self.assertEqual( + api._model_query(mock.sentinel.context, mock.sentinel.model), + mock_session.return_value.query.return_value) + mock_session.assert_called_once_with( + mock.sentinel.context) + mock_session.return_value.query.assert_called_once_with( + mock.sentinel.model) + + def test__update_sqlalchemy_object_fields_non_dict_values(self): + self.assertRaises( + exception.InvalidInput, api._update_sqlalchemy_object_fields, + mock.ANY, mock.ANY, None) + + def test__update_sqlalchemy_object_fields_conflict(self): + updateable_fields = ["field1", "field2"] + values_to_update = {"field1": "value1", "field3": "value3"} + self.assertRaises( + exception.Conflict, api._update_sqlalchemy_object_fields, + mock.ANY, updateable_fields, values_to_update) + + def test__update_sqlalchemy_object_fields_invalid_obj_field(self): + self.assertRaises( + exception.InvalidInput, api._update_sqlalchemy_object_fields, + models.Endpoint, ["invalid_field"], {"invalid_field": "new_value"}) + + def test__update_sqlalchemy_object_fields(self): + obj = models.Endpoint() + obj.description = "initial test description" + new_description = "updated test description" + + api._update_sqlalchemy_object_fields( + obj, ["description"], {"description": new_description}) + self.assertEqual(obj.description, new_description) + + def test__soft_delete_aware_query_show_deleted_kwarg(self): + valid_endpoint = get_valid_endpoint() + self.session.add(valid_endpoint) + self.session.commit() + + testutils.get_wrapped_function(api.delete_endpoint)( + self.context, valid_endpoint.id) + self.context.show_deleted = False + result = api._soft_delete_aware_query( + self.context, models.Endpoint, show_deleted=True).filter( + models.Endpoint.id == valid_endpoint.id).first() + self.assertEqual(result.id, valid_endpoint.id) + self.assertIsNotNone(result.deleted_at) + + def test__soft_delete_aware_query_context_show_deleted(self): + valid_endpoint = get_valid_endpoint() + self.session.add(valid_endpoint) + self.session.commit() + + testutils.get_wrapped_function(api.delete_endpoint)( + self.context, valid_endpoint.id) + self.context.show_deleted = True + result = api._soft_delete_aware_query( + self.context, models.Endpoint).filter( + models.Endpoint.id == valid_endpoint.id).first() + self.assertEqual(result.id, valid_endpoint.id) + self.assertIsNotNone(result.deleted_at) + + +class EndpointDBAPITestCase(BaseDBAPITestCase): + + @classmethod + def setUpClass(cls): + super(EndpointDBAPITestCase, cls).setUpClass() + cls.valid_endpoint_source = cls.valid_data['user_scope'].get( + 'source_endpoint') + cls.valid_endpoint_region_mapping = cls.valid_data['user_scope'].get( + 'endpoint_mapping') + cls.outer_scope_endpoint = cls.valid_data['outer_scope'].get( + 'source_endpoint') + + def test_get_endpoints(self): + result = api.get_endpoints(self.context) + self.assertIn(self.valid_endpoint_source, result) + + def test_get_endpoints_admin(self): + self.context.is_admin = True + result = api.get_endpoints(self.context) + self.assertIn(self.outer_scope_endpoint, result) + + def test_get_endpoints_out_of_user_scope(self): + result = api.get_endpoints(self.context) + self.assertNotIn(self.outer_scope_endpoint, result) + + def test_get_endpoint(self): + result = api.get_endpoint(self.context, self.valid_endpoint_source.id) + self.assertEqual(result, self.valid_endpoint_source) + + def test_get_endpoint_admin_context(self): + self.context.is_admin = True + result = api.get_endpoint(self.context, self.outer_scope_endpoint.id) + self.assertEqual(result, self.outer_scope_endpoint) + + def test_get_endpoint_out_of_user_scope(self): + result = api.get_endpoint(self.context, self.outer_scope_endpoint.id) + self.assertIsNone(result) + + def test_add_endpoint(self): + self.context.user = "2" + self.context.project_id = "2" + new_endpoint_id = str(uuid.uuid4()) + new_endpoint = get_valid_endpoint( + endpoint_id=new_endpoint_id, + connection_info={"conn_info": {"new": "info"}}, + endpoint_type="vmware", name="new_endpoint", + description="New Endpoint") + api.add_endpoint(self.context, new_endpoint) + result = api.get_endpoint(self.context, new_endpoint_id) + self.assertEqual(result, new_endpoint) + + def test_update_endpoint_not_found(self): + self.assertRaises( + exception.NotFound, api.update_endpoint, + self.context, "invalid_id", mock.ANY) + + def test_update_endpoint_invalid_values(self): + self.assertRaises( + exception.InvalidInput, api.update_endpoint, + self.context, self.valid_endpoint_source.id, None) + + def test_update_endpoint_invalid_column(self): + self.assertRaises( + exception.Conflict, api.update_endpoint, + self.context, self.valid_endpoint_source.id, {"type": "openstack"}) + + def test_update_endpoint_region_not_found(self): + self.assertRaises( + exception.NotFound, api.update_endpoint, self.context, + self.valid_endpoint_source.id, + {"mapped_regions": ["invalid_region_id"]}) + + def test_update_endpoint(self): + new_region_id = str(uuid.uuid4()) + new_endpoint_name = "new_name" + new_region = models.Region() + new_region.id = new_region_id + new_region.name = "new_region" + new_region.enabled = True + self.session.add(new_region) + self.session.commit() + + api.update_endpoint( + self.context, self.valid_endpoint_source.id, + {"mapped_regions": [new_region_id], "name": new_endpoint_name}) + result = api.get_endpoint(self.context, self.valid_endpoint_source.id) + old_endpoint_region_mapping = api.get_endpoint_region_mapping( + self.context, self.valid_endpoint_source.id, self.valid_region.id) + new_endpoint_region_mapping = api.get_endpoint_region_mapping( + self.context, self.valid_endpoint_source.id, new_region_id)[0] + self.assertEqual(result.name, new_endpoint_name) + self.assertEqual(old_endpoint_region_mapping, []) + self.assertEqual(new_endpoint_region_mapping.region_id, new_region_id) + self.assertEqual( + new_endpoint_region_mapping.endpoint_id, + self.valid_endpoint_source.id) + + @mock.patch.object(api, 'delete_endpoint_region_mapping') + @mock.patch.object(api, 'add_endpoint_region_mapping') + @mock.patch.object(api, 'get_region') + @mock.patch.object(api, '_update_sqlalchemy_object_fields') + def test_update_endpoint_remapping_failure( + self, mock_update_obj, mock_get_region, mock_add_mapping, + mock_delete_mapping): + mock_add_mapping.side_effect = [Exception, None] + + self.assertRaises( + Exception, api.update_endpoint, + self.context, self.valid_endpoint_source.id, + {"mapped_regions": [mock.sentinel.region_id]}) + mock_get_region.assert_called_with( + self.context, mock.sentinel.region_id) + + mock_delete_mapping.side_effect = Exception + mock_update_obj.side_effect = Exception + self.assertRaises( + Exception, api.update_endpoint, self.context, + self.valid_endpoint_source.id, + {"mapped_regions": [mock.sentinel.region_id]}) + + def test_delete_endpoint(self): + new_endpoint = get_valid_endpoint() + new_endpoint_id = new_endpoint.id + new_endpoint_region_mapping = self.valid_endpoint_region_mapping + new_endpoint_region_mapping.endpoint_id = new_endpoint_id + api.add_endpoint(self.context, new_endpoint) + + api.delete_endpoint(self.context, new_endpoint_id) + result = api.get_endpoint(self.context, new_endpoint_id) + mappings = api.get_endpoint_region_mapping( + self.context, new_endpoint_id, self.valid_region.id) + self.assertIsNone(result) + self.assertEqual(mappings, []) + + def test_delete_endpoint_not_found(self): + self.assertRaises( + exception.NotFound, api.delete_endpoint, self.context, "no_id") + + def test_delete_endpoint_admin_context(self): + self.context.is_admin = True + self.context.show_deleted = True + new_outer_scope_endpoint = get_valid_endpoint() + new_outer_scope_endpoint.user_id = "3" + new_outer_scope_endpoint.project_id = "3" + api.add_endpoint(self.context, new_outer_scope_endpoint) + + api.delete_endpoint( + self.context, new_outer_scope_endpoint.id) + result = api.get_endpoint(self.context, new_outer_scope_endpoint.id) + self.assertIsNotNone(result.deleted_at) + + def test_delete_endpoint_out_of_user_scope(self): + new_outer_scope_endpoint = get_valid_endpoint( + user_id="3", project_id="3") + self.session.add(new_outer_scope_endpoint) + self.session.commit() + + self.assertRaises( + exception.NotFound, api.delete_endpoint, self.context, + new_outer_scope_endpoint.id) + + +class TransferTasksExecutionDBAPITestCase(BaseDBAPITestCase): + + @classmethod + def setUpClass(cls): + super(TransferTasksExecutionDBAPITestCase, cls).setUpClass() + cls.valid_transfer = cls.valid_data['user_scope'].get('transfer') + cls.valid_task = cls.valid_data['user_scope'].get('task') + cls.valid_tasks_execution = cls.valid_data['user_scope'].get( + 'tasks_execution') + cls.outer_scope_transfer = cls.valid_data['outer_scope'].get( + 'transfer') + cls.outer_scope_tasks_execution = cls.valid_data['outer_scope'].get( + "tasks_execution") + + @staticmethod + def _create_dummy_execution(action): + new_tasks_execution = models.TasksExecution() + new_tasks_execution.id = str(uuid.uuid4()) + new_tasks_execution.action = action + new_tasks_execution.status = constants.EXECUTION_STATUS_UNEXECUTED + new_tasks_execution.type = constants.EXECUTION_TYPE_REPLICA_EXECUTION + new_tasks_execution.number = 0 + + return new_tasks_execution + + def test_get_transfer_tasks_executions_include_info(self): + result = api.get_transfer_tasks_executions( + self.context, self.valid_transfer.id, include_task_info=True) + self.assertTrue(hasattr(result[0].action, 'info')) + + def test_get_transfer_tasks_executions_include_tasks(self): + result = api.get_transfer_tasks_executions( + self.context, self.valid_transfer.id, include_tasks=True) + tasks = [] + for e in result: + tasks.extend(e.tasks) + + self.assertIn(self.valid_task, tasks) + + def test_get_transfer_tasks_executions_to_dict(self): + result = api.get_transfer_tasks_executions( + self.context, self.valid_transfer.id, to_dict=True) + execution_ids = [e['id'] for e in result] + self.assertIn(self.valid_tasks_execution.id, execution_ids) + + def test_get_transfer_tasks_executions(self): + result = api.get_transfer_tasks_executions( + self.context, self.valid_transfer.id) + self.assertIn(self.valid_tasks_execution, result) + + def test_get_transfer_tasks_executions_admin(self): + self.context.is_admin = True + result = api.get_transfer_tasks_executions( + self.context, self.outer_scope_transfer.id) + self.assertIn(self.outer_scope_tasks_execution, result) + + def test_get_transfer_tasks_execution_out_of_user_scope(self): + result = api.get_transfer_tasks_executions( + self.context, self.outer_scope_transfer.id) + self.assertEqual(result, []) + + def test_get_transfer_tasks_execution(self): + result = api.get_transfer_tasks_execution( + self.context, self.valid_transfer.id, + self.valid_tasks_execution.id) + self.assertEqual(result, self.valid_tasks_execution) + + def test_get_transfer_tasks_execution_admin(self): + self.context.is_admin = True + result = api.get_transfer_tasks_execution( + self.context, self.outer_scope_transfer.id, + self.outer_scope_tasks_execution.id) + self.assertEqual(result, self.outer_scope_tasks_execution) + + def test_get_transfer_tasks_execution_out_of_user_context(self): + result = api.get_transfer_tasks_execution( + self.context, self.outer_scope_transfer.id, + self.outer_scope_tasks_execution.id) + self.assertIsNone(result) + + def test_get_transfer_tasks_execution_include_task_info(self): + result = api.get_transfer_tasks_execution( + self.context, self.valid_transfer.id, + self.valid_tasks_execution.id, include_task_info=True) + self.assertTrue(hasattr(result.action, 'info')) + + def test_get_transfer_tasks_execution_to_dict(self): + result = api.get_transfer_tasks_execution( + self.context, self.valid_transfer.id, + self.valid_tasks_execution.id, to_dict=True) + self.assertEqual(result['id'], self.valid_tasks_execution.id) + + def test_add_transfer_tasks_execution(self): + new_tasks_execution = self._create_dummy_execution(self.valid_transfer) + + api.add_transfer_tasks_execution(self.context, new_tasks_execution) + result = api.get_transfer_tasks_execution( + self.context, self.valid_transfer.id, new_tasks_execution.id) + self.assertEqual(new_tasks_execution, result) + self.assertGreater(result.number, 0) + + def test_add_transfer_tasks_execution_admin(self): + self.context.is_admin = True + new_tasks_execution = self._create_dummy_execution( + self.outer_scope_transfer) + api.add_transfer_tasks_execution(self.context, new_tasks_execution) + result = api.get_transfer_tasks_execution( + self.context, self.outer_scope_transfer.id, new_tasks_execution.id) + self.assertEqual(new_tasks_execution, result) + + def test_add_transfer_tasks_execution_out_of_user_context(self): + new_tasks_execution = self._create_dummy_execution( + self.outer_scope_transfer) + self.assertRaises( + exception.NotAuthorized, api.add_transfer_tasks_execution, + self.context, new_tasks_execution) -class DBAPITestCase(test_base.CoriolisBaseTestCase): - """Test suite for the Coriolis DB API.""" + def test_delete_transfer_tasks_execution(self): + new_tasks_execution = self._create_dummy_execution(self.valid_transfer) + api.add_transfer_tasks_execution(self.context, new_tasks_execution) + api.delete_transfer_tasks_execution( + self.context, new_tasks_execution.id) + result = api.get_transfer_tasks_execution( + self.context, self.valid_transfer.id, new_tasks_execution.id) + self.assertIsNone(result) - @mock.patch.object(api, 'get_endpoint') - def test_update_endpoint_not_found(self, mock_get_endpoint): - mock_get_endpoint.return_value = None + def test_delete_transfer_tasks_execution_admin(self): + self.context.is_admin = True + new_tasks_execution = self._create_dummy_execution( + self.outer_scope_transfer) + api.add_transfer_tasks_execution(self.context, new_tasks_execution) + api.delete_transfer_tasks_execution( + self.context, new_tasks_execution.id) + result = api.get_transfer_tasks_execution( + self.context, self.outer_scope_transfer.id, new_tasks_execution.id) + self.assertIsNone(result) - # We only need to test the unwrapped functions. Without this, - # when calling a coriolis.db.api function, it will try to - # establish an SQL connection. - update_endpoint = testutils.get_wrapped_function(api.update_endpoint) + def test_delete_transfer_tasks_execution_out_of_user_scope(self): + self.context.is_admin = True + new_tasks_execution = self._create_dummy_execution( + self.outer_scope_transfer) + api.add_transfer_tasks_execution(self.context, new_tasks_execution) - self.assertRaises(exception.NotFound, update_endpoint, - mock.sentinel.context, mock.sentinel.endpoint_id, - mock.sentinel.updated_values) + self.context.is_admin = False + self.assertRaises( + exception.NotAuthorized, api.delete_transfer_tasks_execution, + self.context, new_tasks_execution.id) - mock_get_endpoint.assert_called_once_with(mock.sentinel.context, - mock.sentinel.endpoint_id) + def test_delete_transfer_tasks_execution_not_found(self): + self.context.is_admin = True + self.assertRaises( + exception.NotFound, api.delete_transfer_tasks_execution, + self.context, "invalid_id") diff --git a/tox.ini b/tox.ini index e4c89e17f..432a134f1 100644 --- a/tox.ini +++ b/tox.ini @@ -36,5 +36,5 @@ omit = coriolis/tests/* # E125 is deliberately excluded. See https://github.com/jcrocholl/pep8/issues/126 # E251 Skipped due to https://github.com/jcrocholl/pep8/issues/301 -ignore = E125,E251,W503,W504,E305,E731,E117,W605,F632,H401,H403,H404,H405 +ignore = E125,E251,W503,W504,E305,E731,E117,W605,F632,H401,H403,H404,H405,H202 exclude = .venv,.git,.tox,dist,doc,*openstack/common*,*lib/python*,*egg,build,tools