diff --git a/turbinia/state_manager.py b/turbinia/state_manager.py
index 416ea4a94..9dfc4bae5 100644
--- a/turbinia/state_manager.py
+++ b/turbinia/state_manager.py
@@ -22,6 +22,7 @@
 import json
 import logging
 import sys
+from celery import states as celery_states
 from datetime import datetime
 from datetime import timedelta
 from typing import Any, List, Dict, Optional
@@ -269,7 +270,6 @@ def update_request_task(self, task) -> None:
       task (TurbiniaTask): Turbinia task object.
     """
     request_key = self.redis_client.build_key_name('request', task.request_id)
-    task_key = self.redis_client.build_key_name('task', task.id)
     try:
       self.redis_client.add_to_list(request_key, 'task_ids', task.id)
       request_last_update = datetime.strptime(
@@ -291,12 +291,12 @@ def update_request_task(self, task) -> None:
       elif task.result.successful is False:
         self.redis_client.add_to_list(request_key, 'failed_tasks', task.id)
         statuses_to_remove.remove('failed_tasks')
-      task_status = self.redis_client.get_attribute(task_key, 'status')
-      if task_status and 'Task is running' in task_status:
+      task_status = task.celery_state
+      if task_status == celery_states.STARTED:
         self.redis_client.add_to_list(request_key, 'running_tasks', task.id)
         statuses_to_remove.remove('running_tasks')
-      elif (task_status is None or task_status == 'queued' or
-            task_status == 'pending'):
+      elif (task_status is None or task_status == celery_states.RECEIVED or
+            task_status == celery_states.PENDING):
         self.redis_client.add_to_list(request_key, 'queued_tasks', task.id)
         statuses_to_remove.remove('queued_tasks')
       for status_name in statuses_to_remove:
diff --git a/turbinia/task_manager.py b/turbinia/task_manager.py
index 2c46f1567..6f55098b5 100644
--- a/turbinia/task_manager.py
+++ b/turbinia/task_manager.py
@@ -687,6 +687,14 @@ def _backend_setup(self, *args, **kwargs):
     self.celery_runner = self.celery.app.task(
         task_utils.task_runner, name='task_runner')
 
+  def _get_worker_name(self, celery_stub):
+    """Gets the Celery worker name from the AsyncResult object."""
+    worker_name = celery_stub.result.get('hostname', None)
+    if worker_name:
+      # example hostname: celery@turbinia-hostname
+      worker_name = worker_name.split('@')[1]
+    return worker_name
+
   def process_tasks(self):
     """Determine the current state of our tasks.
 
@@ -700,28 +708,33 @@ def process_tasks(self):
       celery_task = task.stub
       # ref: https://docs.celeryq.dev/en/stable/reference/celery.states.html
       if not celery_task:
-        log.debug(f'Task {task.stub.task_id:s} not yet created.')
+        log.info(f'Task {task.stub.task_id:s} not yet created.')
         check_timeout = True
       elif celery_task.status == celery_states.STARTED:
-        # Task status will be set to running when the worker executes run_wrapper()
-        log.debug(f'Task {celery_task.id:s} not finished.')
+        log.warning(f'Task {celery_task.id:s} {task.id} started.')
+        task.worker_name = self._get_worker_name(celery_task)
+        task.celery_state = celery_states.STARTED
         check_timeout = True
       elif celery_task.status == celery_states.FAILURE:
-        log.warning(f'Task {celery_task.id:s} failed.')
+        log.info(f'Task {celery_task.id:s} failed.')
+        task.celery_state = celery_states.FAILURE
         self.close_failed_task(task)
         completed_tasks.append(task)
       elif celery_task.status == celery_states.SUCCESS:
+        task.celery_state = celery_states.SUCCESS
         task.result = workers.TurbiniaTaskResult.deserialize(celery_task.result)
         completed_tasks.append(task)
       elif celery_task.status == celery_states.PENDING:
-        task.status = 'pending'
+        log.info(f'Task {celery_task.id:s} is pending.')
+        task.celery_state = celery_states.PENDING
         check_timeout = True
-        log.debug(f'Task {celery_task.id:s} is pending.')
       elif celery_task.status == celery_states.RECEIVED:
-        task.status = 'queued'
+        log.info(f'Task {celery_task.id:s} is queued.')
+        task.worker_name = self._get_worker_name(celery_task)
+        task.celery_state = celery_states.RECEIVED
         check_timeout = True
-        log.debug(f'Task {celery_task.id:s} is queued.')
       elif celery_task.status == celery_states.REVOKED:
+        task.celery_state = celery_states.REVOKED
         message = (
             f'Celery task {celery_task.id:s} associated with Turbinia '
             f'task {task.id} was revoked. This could be caused if the task is '
@@ -746,6 +759,9 @@ def process_tasks(self):
         task = self.timeout_task(task, timeout)
         completed_tasks.append(task)
 
+      # Update task metadata so we have an up-to-date state.
+      self.state_manager.update_task(task)
+
     outstanding_task_count = len(self.tasks) - len(completed_tasks)
     if outstanding_task_count > 0:
       log.info(f'{outstanding_task_count:d} Tasks still outstanding.')
diff --git a/turbinia/tcelery.py b/turbinia/tcelery.py
index 9d0b3b200..1e006a815 100644
--- a/turbinia/tcelery.py
+++ b/turbinia/tcelery.py
@@ -57,6 +57,7 @@ def setup(self):
         task_default_queue=config.INSTANCE_ID,
         # Re-queue task if Celery worker abruptly exists
         task_reject_on_worker_lost=True,
+        task_track_started=True,
         worker_cancel_long_running_tasks_on_connection_loss=True,
         worker_concurrency=1,
         worker_prefetch_multiplier=1,
diff --git a/turbinia/workers/__init__.py b/turbinia/workers/__init__.py
index 3b11d6ab9..8ebfb9c78 100644
--- a/turbinia/workers/__init__.py
+++ b/turbinia/workers/__init__.py
@@ -129,8 +129,8 @@ class TurbiniaTaskResult:
 
   # The list of attributes that we will persist into storage
   STORED_ATTRIBUTES = [
-      'worker_name', 'report_data', 'report_priority', 'run_time', 'status',
-      'saved_paths', 'successful', 'evidence_size'
+      'report_data', 'report_priority', 'run_time', 'status', 'saved_paths',
+      'successful', 'evidence_size', 'worker_name'
   ]
 
   def __init__(
@@ -454,7 +454,7 @@ class TurbiniaTask:
   STORED_ATTRIBUTES = [
       'id', 'job_id', 'job_name', 'start_time', 'last_update', 'name',
       'evidence_name', 'evidence_id', 'request_id', 'requester', 'group_name',
-      'reason', 'group_id', 'celery_id'
+      'reason', 'group_id', 'celery_id', 'celery_state', 'worker_name'
   ]
 
   # The list of evidence states that are required by a Task in order to run.
@@ -483,6 +483,7 @@ def __init__(
 
     self.id = uuid.uuid4().hex
     self.celery_id = None
+    self.celery_state = None
     self.is_finalize_task = False
     self.job_id = None
     self.job_name = None
@@ -508,6 +509,7 @@ def __init__(
     self.group_name = group_name
     self.reason = reason
     self.group_id = group_id
+    self.worker_name = None
 
   def serialize(self):
     """Converts the TurbiniaTask object into a serializable dict.
@@ -821,6 +823,7 @@ def setup(self, evidence):
       TurbiniaException: If the evidence can not be found.
     """
     self.setup_metrics()
+    self.state_manager = state_manager.get_state_manager()
     self.output_manager.setup(self.name, self.id, self.request_id)
     self.tmp_dir, self.output_dir = self.output_manager.get_local_output_dirs()
     if not self.result:
@@ -1039,21 +1042,10 @@ def run_wrapper(self, evidence):
 
     log.debug(f'Task {self.name:s} {self.id:s} awaiting execution')
     failure_message = None
-    worker_name = platform.node()
+    self.worker_name = platform.node()
     try:
       evidence = evidence_decode(evidence)
       self.result = self.setup(evidence)
-      # Call update_task_status to update status
-      # We cannot call update_task() here since it will clobber previously
-      # stored data by the Turbinia server when the task was created, which is
-      # not present in the TurbiniaTask object the worker currently has in its
-      # runtime.
-      self.update_task_status(self, 'queued')
-      # Beucase of the same reason, we perform a single attribute update
-      # for the worker name.
-      task_key = self.state_manager.redis_client.build_key_name('task', self.id)
-      self.state_manager.redis_client.set_attribute(
-          task_key, 'worker_name', json.dumps(worker_name))
       turbinia_worker_tasks_queued_total.inc()
       task_runtime_metrics = self.get_metrics()
     except TurbiniaException as exception:
@@ -1108,9 +1100,6 @@ def run_wrapper(self, evidence):
       self._evidence_config = evidence.config
       self.task_config = self.get_task_recipe(evidence.config)
       self.worker_start_time = datetime.now()
-      # Update task status so we know which worker the task executed on.
-      updated_status = f'Task is running on {worker_name}'
-      self.update_task_status(self, updated_status)
       self.result = self.run(evidence, self.result)
 
     # pylint: disable=broad-except
@@ -1161,7 +1150,7 @@ def run_wrapper(self, evidence):
         failure_message = (
            'Task Result was auto-closed from task executor on {0:s} likely '
            'due to previous failures. Previous status: [{1:s}]'.format(
-                self.result.worker_name, status))
+                self.worker_name, status))
         self.result.log(failure_message)
         try:
           self.result.close(self, False, failure_message)
@@ -1197,23 +1186,3 @@ def run(self, evidence, result):
           TurbiniaTaskResult object.
     """
     raise NotImplementedError
-
-  def update_task_status(self, task, status=None):
-    """Updates the task status and pushes it directly to datastore.
-
-    Args:
-      task (TurbiniaTask): The calling Task object
-      status (str): Brief word or phrase for Task state. If not supplied, the
-          existing Task status will be used.
-    """
-    if not status:
-      return
-    if not self.state_manager:
-      self.state_manager = state_manager.get_state_manager()
-    if self.state_manager:
-      task_key = self.state_manager.redis_client.build_key_name('task', task.id)
-      self.state_manager.redis_client.set_attribute(
-          task_key, 'status', json.dumps(status))
-      self.state_manager.update_request_task(task)
-    else:
-      log.info('No state_manager initialized, not updating Task info')
diff --git a/turbinia/workers/workers_test.py b/turbinia/workers/workers_test.py
index 321ba1dc2..7331b53ec 100644
--- a/turbinia/workers/workers_test.py
+++ b/turbinia/workers/workers_test.py
@@ -106,7 +106,6 @@ def setResults(
     self.result.input_evidence = evidence.RawDisk()
     self.result.status = 'TestStatus'
-    self.update_task_status = mock.MagicMock()
     self.result.close = mock.MagicMock()
     self.task.setup = mock.MagicMock(return_value=setup)
     self.result.worker_name = 'worker1'
 
@@ -147,8 +146,10 @@ def testTurbiniaTaskSerialize(self):
     self.assertEqual(out_obj.__dict__, self.plaso_task.__dict__)
 
   @mock.patch('turbinia.redis_client.RedisClient.set_attribute')
-  @mock.patch('turbinia.workers.TurbiniaTask.update_task_status')
-  def testTurbiniaTaskRunWrapper(self, _, __):
+  def testTurbiniaTaskRunWrapper(
+      self,
+      _,
+  ):
     """Test that the run wrapper executes task run."""
     self.setResults()
     self.result.closed = True
@@ -159,8 +160,10 @@ def testTurbiniaTaskRunWrapper(self, _, __):
     self.result.close.assert_not_called()
 
   @mock.patch('turbinia.redis_client.RedisClient.set_attribute')
-  @mock.patch('turbinia.workers.TurbiniaTask.update_task_status')
-  def testTurbiniaTaskRunWrapperAutoClose(self, _, __):
+  def testTurbiniaTaskRunWrapperAutoClose(
+      self,
+      _,
+  ):
     """Test that the run wrapper closes the task."""
     self.setResults()
     new_result = self.task.run_wrapper(self.evidence.__dict__)
@@ -170,8 +173,7 @@
 
   @mock.patch('turbinia.state_manager.get_state_manager')
   @mock.patch('turbinia.redis_client.RedisClient.set_attribute')
-  @mock.patch('turbinia.workers.TurbiniaTask.update_task_status')
-  def testTurbiniaTaskRunWrapperBadResult(self, _, __, ___):
+  def testTurbiniaTaskRunWrapperBadResult(self, _, __):
     """Test that the run wrapper recovers from run returning bad result."""
     bad_result = 'Not a TurbiniaTaskResult'
     checked_result = TurbiniaTaskResult(base_output_dir=self.base_output_dir)
@@ -185,8 +187,7 @@
     self.assertIn('CheckedResult', new_result.status)
 
   @mock.patch('turbinia.redis_client.RedisClient.set_attribute')
-  @mock.patch('turbinia.workers.TurbiniaTask.update_task_status')
-  def testTurbiniaTaskJobUnavailable(self, _, __):
+  def testTurbiniaTaskJobUnavailable(self, _):
     """Test that the run wrapper can fail if the job doesn't exist."""
     self.setResults()
     self.task.job_name = 'non_exist'
@@ -198,8 +199,7 @@
     self.assertEqual(new_result.status, canary_status)
 
   @mock.patch('turbinia.redis_client.RedisClient.set_attribute')
-  @mock.patch('turbinia.workers.TurbiniaTask.update_task_status')
-  def testTurbiniaTaskRunWrapperExceptionThrown(self, _, __):
+  def testTurbiniaTaskRunWrapperExceptionThrown(self, _):
     """Test that the run wrapper recovers from run throwing an exception."""
     self.setResults()
     self.task.run = mock.MagicMock(side_effect=TurbiniaException)
@@ -253,9 +253,7 @@ def testTurbiniaTaskValidateResultBadResult(self, _, __):
 
   @mock.patch('turbinia.workers.evidence_decode')
   @mock.patch('turbinia.redis_client.RedisClient.set_attribute')
-  @mock.patch('turbinia.workers.TurbiniaTask.update_task_status')
-  def testTurbiniaTaskEvidenceValidationFailure(
-      self, _, __, evidence_decode_mock):
+  def testTurbiniaTaskEvidenceValidationFailure(self, _, evidence_decode_mock):
     """Tests Task fails when evidence validation fails."""
     self.setResults()
     test_evidence = evidence.RawDisk()
diff --git a/web/src/components/TaskDetails.vue b/web/src/components/TaskDetails.vue
index b217c74ca..f8b9d9f4d 100644
--- a/web/src/components/TaskDetails.vue
+++ b/web/src/components/TaskDetails.vue
@@ -39,11 +39,17 @@ limitations under the License.
 {{ taskDetails.status }}
-
-
- {{ taskDetails.status }}
-
-
Task {{ taskDetails.id }} is pending
+
+ {{ taskDetails.status }}
+
+
+ Task {{ taskDetails.id }} is running on {{ taskDetails.worker_name }}
+
+
+ Task {{ taskDetails.id }} is pending.
+
+
+ Task {{ taskDetails.id }} is queued.
 {{ taskDetails.status }}
@@ -72,6 +78,12 @@ limitations under the License.
N/A
+
+
+ {{ taskDetails.celery_state }} +
+
N/A
+
+
 {{ taskDetails.evidence_id }}
diff --git a/web/src/components/TaskList.vue b/web/src/components/TaskList.vue
index 0349a5716..64d4b881e 100644
--- a/web/src/components/TaskList.vue
+++ b/web/src/components/TaskList.vue
@@ -73,14 +73,20 @@ export default {
         let data = response.data['tasks']
         for (const task in data) {
           let task_dict = data[task]
-          let taskStatusTemp = task_dict.status
+          let taskStatusTemp = task_dict.celery_state
           // As pending status requests show as null or pending
-          if (taskStatusTemp === null || taskStatusTemp === "pending") {
+          if (taskStatusTemp === null || taskStatusTemp === "PENDING") {
             taskStatusTemp = 'is pending on server.'
           }
-          else if (taskStatusTemp == "queued") {
+          else if (taskStatusTemp == "RECEIVED") {
             taskStatusTemp = 'is queued for execution.'
           }
+          else if (taskStatusTemp == "STARTED") {
+            taskStatusTemp = 'is running on ' + task_dict.worker_name
+          }
+          else {
+            taskStatusTemp = task_dict.status
+          }
           if (this.filterJobs.length > 0) {
             let jobName = task_dict.job_name.toLowerCase()
             if (this.radioFilter && !this.filterJobs.includes(jobName)) {
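
For reference, the Celery-state bucketing that update_request_task() applies after this change can be sketched as follows; the bucket_for helper is illustrative only and not part of the patch:

    from celery import states as celery_states

    def bucket_for(celery_state):
      """Maps a task's celery_state onto the request-level task lists used above."""
      if celery_state == celery_states.STARTED:
        # Workers only report STARTED because tcelery.py now sets task_track_started=True.
        return 'running_tasks'
      if celery_state in (None, celery_states.PENDING, celery_states.RECEIVED):
        return 'queued_tasks'
      return None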