From 957f1fb1baad3739d0213ccee98124d5a4be973f Mon Sep 17 00:00:00 2001
From: George Pollard
Date: Wed, 25 Oct 2023 12:07:50 +1300
Subject: [PATCH] Improve slow Task queries (#3594)

Querying Tasks only by TaskID is slow, as it requires querying many
partitions. Querying a Task by JobID and TaskID is fast, as it is a point
lookup.

1. rename existing `GetByTaskId` to `GetByTaskIdSlow`
2. include JobID in event data from the agent (the new properties are
   optional for back-compat)
3. include JobID in the NodeTask entity (if this is missing upon lookup we
   fall back to `GetByTaskIdSlow`)
---
 .../ApiService/Functions/AgentCanSchedule.cs |  7 ++++-
 .../ApiService/Functions/AgentEvents.cs      | 28 ++++++++++++++-----
 .../ApiService/Functions/QueueJobResult.cs   |  2 +-
 .../Functions/QueueTaskHeartbeat.cs          | 13 +++++----
 src/ApiService/ApiService/Functions/Tasks.cs |  7 ++---
 .../ApiService/Functions/TimerRetention.cs   |  2 +-
 .../ApiService/OneFuzzTypes/Model.cs         | 10 ++++---
 .../ApiService/OneFuzzTypes/Requests.cs      | 12 +++++++-
 .../ApiService/onefuzzlib/NodeOperations.cs  | 21 +++++++++++---
 .../ApiService/onefuzzlib/ReproOperations.cs |  8 +++---
 .../ApiService/onefuzzlib/TaskOperations.cs  | 15 +++-------
 .../IntegrationTests/AgentEventsTests.cs     | 11 +++++---
 .../IntegrationTests/_FunctionTestBase.cs    |  3 +-
 src/ApiService/Tests/RequestsTests.cs        | 18 ++++++++----
 src/agent/onefuzz-agent/src/agent.rs         | 13 +++++++--
 src/agent/onefuzz-agent/src/agent/tests.rs   | 12 ++++++--
 src/agent/onefuzz-agent/src/coordinator.rs   | 12 +++++++-
 src/agent/onefuzz-agent/src/debug.rs         | 17 +++++++++--
 src/agent/onefuzz-agent/src/main.rs          |  1 +
 src/agent/onefuzz-agent/src/work.rs          |  4 ---
 src/agent/onefuzz-agent/src/worker.rs        |  4 +++
 src/agent/onefuzz-agent/src/worker/tests.rs  |  4 ++-
 22 files changed, 156 insertions(+), 68 deletions(-)

diff --git a/src/ApiService/ApiService/Functions/AgentCanSchedule.cs b/src/ApiService/ApiService/Functions/AgentCanSchedule.cs
index 4a916dc60d..887b067cc1 100644
--- a/src/ApiService/ApiService/Functions/AgentCanSchedule.cs
+++ b/src/ApiService/ApiService/Functions/AgentCanSchedule.cs
@@ -40,7 +40,12 @@ public async Async.Task Run(
         var allowed = canProcessNewWork.IsAllowed;
         var reason = canProcessNewWork.Reason;
 
-        var task = await _context.TaskOperations.GetByTaskId(canScheduleRequest.TaskId);
+        var task = await
+            (canScheduleRequest.JobId.HasValue
+                ? _context.TaskOperations.GetByJobIdAndTaskId(canScheduleRequest.JobId.Value, canScheduleRequest.TaskId)
+                // old agent, fall back
+                : _context.TaskOperations.GetByTaskIdSlow(canScheduleRequest.TaskId));
+
         var workStopped = task == null || task.State.ShuttingDown();
         if (!allowed) {
             _log.LogInformation("Node cannot process new work {PoolName} {ScalesetId} - {MachineId} ", node.PoolName, node.ScalesetId, node.MachineId);
diff --git a/src/ApiService/ApiService/Functions/AgentEvents.cs b/src/ApiService/ApiService/Functions/AgentEvents.cs
index 8a557b5cd4..18fb181bd5 100644
--- a/src/ApiService/ApiService/Functions/AgentEvents.cs
+++ b/src/ApiService/ApiService/Functions/AgentEvents.cs
@@ -110,18 +110,23 @@ public async Async.Task Run(
             _log.LogInformation("node now available for work: {MachineId}", machineId);
         } else if (ev.State == NodeState.SettingUp) {
             if (ev.Data is NodeSettingUpEventData settingUpData) {
-                if (!settingUpData.Tasks.Any()) {
+                if (settingUpData.Tasks?.Any() == false || settingUpData.TaskData?.Any() == false) {
                     return Error.Create(ErrorCode.INVALID_REQUEST, $"setup without tasks.
machine_id: {machineId}" ); } - foreach (var taskId in settingUpData.Tasks) { - var task = await _context.TaskOperations.GetByTaskId(taskId); + var tasks = + settingUpData.Tasks is not null + ? settingUpData.Tasks.Select(t => (t, _context.TaskOperations.GetByTaskIdSlow(t))) + : settingUpData.TaskData!.Select(t => (t.TaskId, _context.TaskOperations.GetByJobIdAndTaskId(t.JobId, t.TaskId))); + + foreach (var (id, taskTask) in tasks) { + var task = await taskTask; if (task is null) { return Error.Create( ErrorCode.INVALID_REQUEST, - $"unable to find task: {taskId}"); + $"unable to find task: {id}"); } _log.LogInformation("node starting task. {MachineId} {JobId} {TaskId}", machineId, task.JobId, task.TaskId); @@ -139,6 +144,7 @@ public async Async.Task Run( var nodeTask = new NodeTasks( MachineId: machineId, TaskId: task.TaskId, + JobId: task.JobId, State: NodeTaskState.SettingUp); var r = await _context.NodeTasksOperations.Replace(nodeTask); if (!r.IsOk) { @@ -183,7 +189,10 @@ public async Async.Task Run( private async Async.Task OnWorkerEventRunning(Guid machineId, WorkerRunningEvent running) { var (task, node) = await ( - _context.TaskOperations.GetByTaskId(running.TaskId), + (running.JobId.HasValue + ? _context.TaskOperations.GetByJobIdAndTaskId(running.JobId.Value, running.TaskId) + // old agent, fallback + : _context.TaskOperations.GetByTaskIdSlow(running.TaskId)), _context.NodeOperations.GetByMachineId(machineId)); if (task is null) { @@ -202,6 +211,7 @@ public async Async.Task Run( var nodeTask = new NodeTasks( MachineId: machineId, TaskId: running.TaskId, + JobId: running.JobId, State: NodeTaskState.Running); var r = await _context.NodeTasksOperations.Replace(nodeTask); if (!r.IsOk) { @@ -231,8 +241,12 @@ public async Async.Task Run( } private async Async.Task OnWorkerEventDone(Guid machineId, WorkerDoneEvent done) { + var (task, node) = await ( - _context.TaskOperations.GetByTaskId(done.TaskId), + (done.JobId.HasValue + ? 
_context.TaskOperations.GetByJobIdAndTaskId(done.JobId.Value, done.TaskId) + // old agent, fall back + : _context.TaskOperations.GetByTaskIdSlow(done.TaskId)), _context.NodeOperations.GetByMachineId(machineId)); if (task is null) { @@ -285,7 +299,7 @@ await _context.TaskOperations.MarkFailed( } if (!node.DebugKeepNode) { - var r = await _context.NodeTasksOperations.Delete(new NodeTasks(machineId, done.TaskId)); + var r = await _context.NodeTasksOperations.Delete(new NodeTasks(machineId, done.TaskId, done.JobId)); if (!r.IsOk) { _log.AddHttpStatus(r.ErrorV); _log.LogError("failed to deleting node task {TaskId} for: {MachineId} since DebugKeepNode is false", done.TaskId, machineId); diff --git a/src/ApiService/ApiService/Functions/QueueJobResult.cs b/src/ApiService/ApiService/Functions/QueueJobResult.cs index d781a4d1e1..5e3bec0048 100644 --- a/src/ApiService/ApiService/Functions/QueueJobResult.cs +++ b/src/ApiService/ApiService/Functions/QueueJobResult.cs @@ -23,7 +23,7 @@ public async Async.Task Run([QueueTrigger("job-result", Connection = "AzureWebJo _log.LogInformation("job result: {msg}", msg); var jr = JsonSerializer.Deserialize(msg, EntityConverter.GetJsonSerializerOptions()).EnsureNotNull($"wrong data {msg}"); - var task = await _tasks.GetByTaskId(jr.TaskId); + var task = await _tasks.GetByTaskIdSlow(jr.TaskId); if (task == null) { _log.LogWarning("invalid {TaskId}", jr.TaskId); return; diff --git a/src/ApiService/ApiService/Functions/QueueTaskHeartbeat.cs b/src/ApiService/ApiService/Functions/QueueTaskHeartbeat.cs index 850e77f71f..8781367902 100644 --- a/src/ApiService/ApiService/Functions/QueueTaskHeartbeat.cs +++ b/src/ApiService/ApiService/Functions/QueueTaskHeartbeat.cs @@ -25,17 +25,18 @@ public async Async.Task Run([QueueTrigger("task-heartbeat", Connection = "AzureW _log.LogInformation("heartbeat: {msg}", msg); var hb = JsonSerializer.Deserialize(msg, EntityConverter.GetJsonSerializerOptions()).EnsureNotNull($"wrong data {msg}"); - var task = await _tasks.GetByTaskId(hb.TaskId); - if (task == null) { - _log.LogWarning("invalid {TaskId}", hb.TaskId); + var job = await _jobs.Get(hb.JobId); + if (job == null) { + _log.LogWarning("invalid {JobId}", hb.JobId); return; } - var job = await _jobs.Get(task.JobId); - if (job == null) { - _log.LogWarning("invalid {JobId}", task.JobId); + var task = await _tasks.GetByJobIdAndTaskId(hb.JobId, hb.TaskId); + if (task == null) { + _log.LogWarning("invalid {TaskId}", hb.TaskId); return; } + var newTask = task with { Heartbeat = DateTimeOffset.UtcNow }; var r = await _tasks.Replace(newTask); if (!r.IsOk) { diff --git a/src/ApiService/ApiService/Functions/Tasks.cs b/src/ApiService/ApiService/Functions/Tasks.cs index c648fa9305..b0505d2dbc 100644 --- a/src/ApiService/ApiService/Functions/Tasks.cs +++ b/src/ApiService/ApiService/Functions/Tasks.cs @@ -33,7 +33,7 @@ private async Async.Task Get(HttpRequestData req) { } if (request.OkV.TaskId is Guid taskId) { - var task = await _context.TaskOperations.GetByTaskId(taskId); + var task = await _context.TaskOperations.GetByTaskIdSlow(taskId); if (task == null) { return await _context.RequestHandling.NotOk( req, @@ -128,8 +128,7 @@ private async Async.Task Post(HttpRequestData req, FunctionCon if (cfg.PrereqTasks != null) { foreach (var taskId in cfg.PrereqTasks) { - var prereq = await _context.TaskOperations.GetByTaskId(taskId); - + var prereq = await _context.TaskOperations.GetByJobIdAndTaskId(cfg.JobId, taskId); if (prereq == null) { return await _context.RequestHandling.NotOk( req, @@ -163,7 
+162,7 @@ private async Async.Task Delete(HttpRequestData req) { } - var task = await _context.TaskOperations.GetByTaskId(request.OkV.TaskId); + var task = await _context.TaskOperations.GetByTaskIdSlow(request.OkV.TaskId); if (task == null) { return await _context.RequestHandling.NotOk(req, Error.Create(ErrorCode.INVALID_REQUEST, "unable to find task" ), "task delete"); diff --git a/src/ApiService/ApiService/Functions/TimerRetention.cs b/src/ApiService/ApiService/Functions/TimerRetention.cs index 284dcdbfb3..b4d1df61f9 100644 --- a/src/ApiService/ApiService/Functions/TimerRetention.cs +++ b/src/ApiService/ApiService/Functions/TimerRetention.cs @@ -87,7 +87,7 @@ from container in task.Config.Containers } } else if (Guid.TryParse(q.Name, out queueId)) { //this is a task queue - var taskQueue = await _taskOps.GetByTaskId(queueId); + var taskQueue = await _taskOps.GetByTaskIdSlow(queueId); if (taskQueue is null) { // task does not exist. Ok to delete the task queue _log.LogInformation("Deleting {TaskQueueName} since task could not be found in Task table", q.Name); diff --git a/src/ApiService/ApiService/OneFuzzTypes/Model.cs b/src/ApiService/ApiService/OneFuzzTypes/Model.cs index 54847d0e1b..ab41853a74 100644 --- a/src/ApiService/ApiService/OneFuzzTypes/Model.cs +++ b/src/ApiService/ApiService/OneFuzzTypes/Model.cs @@ -1,4 +1,5 @@ -using System.Reflection; +using System.ComponentModel.DataAnnotations; +using System.Reflection; using System.Text.Json; using System.Text.Json.Serialization; using System.Text.RegularExpressions; @@ -49,9 +50,9 @@ public enum JobResultType { public record HeartbeatData(HeartbeatType Type); public record TaskHeartbeatEntry( - Guid TaskId, - Guid? JobId, - Guid MachineId, + [property: Required] Guid TaskId, + [property: Required] Guid JobId, + [property: Required] Guid MachineId, HeartbeatData[] Data); public record JobResultData(JobResultType Type); @@ -99,6 +100,7 @@ public record NodeTasks ( [PartitionKey] Guid MachineId, [RowKey] Guid TaskId, + Guid? JobId, // not necessarily populated in old records NodeTaskState State = NodeTaskState.Init ) : StatefulEntityBase(State); diff --git a/src/ApiService/ApiService/OneFuzzTypes/Requests.cs b/src/ApiService/ApiService/OneFuzzTypes/Requests.cs index f3cc407b15..db63499d30 100644 --- a/src/ApiService/ApiService/OneFuzzTypes/Requests.cs +++ b/src/ApiService/ApiService/OneFuzzTypes/Requests.cs @@ -11,6 +11,7 @@ public record BaseRequest { public record CanScheduleRequest( [property: Required] Guid MachineId, + Guid? JobId, [property: Required] Guid TaskId ) : BaseRequest; @@ -63,9 +64,11 @@ public record WorkerEvent( ) : NodeEventBase; public record WorkerRunningEvent( + Guid? JobId, [property: Required] Guid TaskId); public record WorkerDoneEvent( + Guid? JobId, [property: Required] Guid TaskId, [property: Required] ExitStatus ExitStatus, [property: Required] string Stderr, @@ -81,8 +84,15 @@ public record NodeStateUpdate( [JsonConverter(typeof(SubclassConverter))] public abstract record NodeStateData; +public record NodeSettingUpData( + [property: Required] Guid JobId, + [property: Required] Guid TaskId); + +// TODO [future]: remove Tasks and make TaskData Required +// once all agents are compatible public record NodeSettingUpEventData( - [property: Required] List Tasks + List? Tasks, + List? 
TaskData ) : NodeStateData; public record NodeDoneEventData( diff --git a/src/ApiService/ApiService/onefuzzlib/NodeOperations.cs b/src/ApiService/ApiService/onefuzzlib/NodeOperations.cs index b6fa3c64d0..83d06a5df1 100644 --- a/src/ApiService/ApiService/onefuzzlib/NodeOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/NodeOperations.cs @@ -628,7 +628,12 @@ public IAsyncEnumerable SearchByPoolName(PoolName poolName) { public async Async.Task MarkTasksStoppedEarly(Node node, Error? error) { await foreach (var entry in _context.NodeTasksOperations.GetByMachineId(node.MachineId)) { - var task = await _context.TaskOperations.GetByTaskId(entry.TaskId); + var task = await + (entry.JobId.HasValue + ? _context.TaskOperations.GetByJobIdAndTaskId(entry.JobId.Value, entry.TaskId) + // old data might not have job ID: + : _context.TaskOperations.GetByTaskIdSlow(entry.TaskId)); + if (task is not null && !TaskStateHelper.ShuttingDown(task.State)) { var message = $"Node {node.MachineId} stopping while the task state is '{task.State}'"; if (error is not null) { @@ -701,10 +706,18 @@ public async Task> AddSshPublicKey(Node node, string publicK /// returns True on stopping the node and False if this doesn't stop the node private async Task StopIfComplete(Node node, bool done = false) { - var nodeTaskIds = await _context.NodeTasksOperations.GetByMachineId(node.MachineId).Select(nt => nt.TaskId).ToArrayAsync(); - var tasks = _context.TaskOperations.GetByTaskIds(nodeTaskIds); + var tasks = _context.NodeTasksOperations.GetByMachineId(node.MachineId) + .SelectAwait(async node => { + if (node.JobId.HasValue) { + return await _context.TaskOperations.GetByJobIdAndTaskId(node.JobId.Value, node.TaskId); + } else { + // old existing records might not have jobId - fall back to slow lookup + return await _context.TaskOperations.GetByTaskIdSlow(node.TaskId); + } + }); + await foreach (var task in tasks) { - if (!TaskStateHelper.ShuttingDown(task.State)) { + if (task is not null && !TaskStateHelper.ShuttingDown(task.State)) { return false; } } diff --git a/src/ApiService/ApiService/onefuzzlib/ReproOperations.cs b/src/ApiService/ApiService/onefuzzlib/ReproOperations.cs index cd3e289402..3d588fd2ec 100644 --- a/src/ApiService/ApiService/onefuzzlib/ReproOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/ReproOperations.cs @@ -45,7 +45,7 @@ public IAsyncEnumerable SearchExpired() { public async Async.Task GetVm(Repro repro, InstanceConfig config) { var taskOperations = _context.TaskOperations; var tags = config.VmTags; - var task = await taskOperations.GetByTaskId(repro.TaskId); + var task = await taskOperations.GetByTaskIdSlow(repro.TaskId); if (task == null) { throw new Exception($"previous existing task missing: {repro.TaskId}"); } @@ -242,7 +242,7 @@ public async Task BuildReproScript(Repro repro) { ); } - var task = await _context.TaskOperations.GetByTaskId(repro.TaskId); + var task = await _context.TaskOperations.GetByTaskIdSlow(repro.TaskId); if (task == null) { return OneFuzzResultVoid.Error( ErrorCode.VM_CREATE_FAILED, @@ -324,7 +324,7 @@ public async Async.Task SetError(Repro repro, Error result) { } public async Task GetSetupContainer(Repro repro) { - var task = await _context.TaskOperations.GetByTaskId(repro.TaskId); + var task = await _context.TaskOperations.GetByTaskIdSlow(repro.TaskId); return task?.Config?.Containers? .Where(container => container.Type == ContainerType.Setup) .FirstOrDefault()? 
@@ -337,7 +337,7 @@ public async Task> Create(ReproConfig config, UserInfo user return OneFuzzResult.Error(ErrorCode.UNABLE_TO_FIND, "unable to find report"); } - var task = await _context.TaskOperations.GetByTaskId(report.TaskId); + var task = await _context.TaskOperations.GetByTaskIdSlow(report.TaskId); if (task is null) { return OneFuzzResult.Error(ErrorCode.INVALID_REQUEST, "unable to find task"); } diff --git a/src/ApiService/ApiService/onefuzzlib/TaskOperations.cs b/src/ApiService/ApiService/onefuzzlib/TaskOperations.cs index a7a4cd0ebb..5dbac7823f 100644 --- a/src/ApiService/ApiService/onefuzzlib/TaskOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/TaskOperations.cs @@ -5,9 +5,7 @@ namespace Microsoft.OneFuzz.Service; public interface ITaskOperations : IStatefulOrm { - Async.Task GetByTaskId(Guid taskId); - - IAsyncEnumerable GetByTaskIds(IEnumerable taskId); + Task GetByTaskIdSlow(Guid taskId); IAsyncEnumerable GetByJobId(Guid jobId); @@ -47,13 +45,8 @@ public TaskOperations(ILogger log, IMemoryCache cache, IOnefuzzC _cache = cache; } - public async Async.Task GetByTaskId(Guid taskId) { - return await GetByTaskIds(new[] { taskId }).FirstOrDefaultAsync(); - } - - public IAsyncEnumerable GetByTaskIds(IEnumerable taskId) { - return QueryAsync(filter: Query.RowKeys(taskId.Select(t => t.ToString()))); - } + public async Async.Task GetByTaskIdSlow(Guid taskId) + => await QueryAsync(filter: Query.RowKey(taskId.ToString())).FirstOrDefaultAsync(); public IAsyncEnumerable GetByJobId(Guid jobId) { return QueryAsync(Query.PartitionKey(jobId.ToString())); @@ -276,7 +269,7 @@ private async Async.Task OnStart(Task task) { public async Async.Task CheckPrereqTasks(Task task) { if (task.Config.PrereqTasks != null) { foreach (var taskId in task.Config.PrereqTasks) { - var t = await GetByTaskId(taskId); + var t = await GetByJobIdAndTaskId(task.JobId, taskId); // if a prereq task fails, then mark this task as failed if (t == null) { diff --git a/src/ApiService/IntegrationTests/AgentEventsTests.cs b/src/ApiService/IntegrationTests/AgentEventsTests.cs index 18c1485b96..73a93b8397 100644 --- a/src/ApiService/IntegrationTests/AgentEventsTests.cs +++ b/src/ApiService/IntegrationTests/AgentEventsTests.cs @@ -61,6 +61,7 @@ await Context.InsertAll( var data = new NodeStateEnvelope( MachineId: _machineId, Event: new WorkerEvent(Done: new WorkerDoneEvent( + JobId: _jobId, TaskId: _taskId, ExitStatus: new ExitStatus(Code: 0, Signal: 0, Success: true), "stderr", @@ -88,6 +89,7 @@ await Context.InsertAll( var data = new NodeStateEnvelope( MachineId: _machineId, Event: new WorkerEvent(Done: new WorkerDoneEvent( + JobId: _jobId, TaskId: _taskId, ExitStatus: new ExitStatus(Code: 0, Signal: 0, Success: false), // unsuccessful result "stderr", @@ -114,6 +116,7 @@ await Context.InsertAll( var data = new NodeStateEnvelope( MachineId: _machineId, Event: new WorkerEvent(Done: new WorkerDoneEvent( + JobId: _jobId, TaskId: _taskId, ExitStatus: new ExitStatus(0, 0, true), "stderr", @@ -137,7 +140,7 @@ await Context.InsertAll( var func = new AgentEvents(LoggerProvider.CreateLogger(), Context); var data = new NodeStateEnvelope( MachineId: _machineId, - Event: new WorkerEvent(Running: new WorkerRunningEvent(_taskId))); + Event: new WorkerEvent(Running: new WorkerRunningEvent(_jobId, _taskId))); var result = await func.Run(TestHttpRequestData.FromJson("POST", data)); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); @@ -153,7 +156,7 @@ await Context.InsertAll( var func = new 
AgentEvents(LoggerProvider.CreateLogger(), Context); var data = new NodeStateEnvelope( MachineId: _machineId, - Event: new WorkerEvent(Running: new WorkerRunningEvent(_taskId))); + Event: new WorkerEvent(Running: new WorkerRunningEvent(_jobId, _taskId))); var result = await func.Run(TestHttpRequestData.FromJson("POST", data)); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); @@ -170,7 +173,7 @@ await Context.InsertAll( var func = new AgentEvents(LoggerProvider.CreateLogger(), Context); var data = new NodeStateEnvelope( MachineId: _machineId, - Event: new WorkerEvent(Running: new WorkerRunningEvent(_taskId))); + Event: new WorkerEvent(Running: new WorkerRunningEvent(_jobId, _taskId))); var result = await func.Run(TestHttpRequestData.FromJson("POST", data)); Assert.Equal(HttpStatusCode.OK, result.StatusCode); @@ -199,7 +202,7 @@ await Async.Task.WhenAll( var taskEvent = await Context.TaskEventOperations.SearchAll().SingleAsync(); Assert.Equal(_taskId, taskEvent.TaskId); Assert.Equal(_machineId, taskEvent.MachineId); - Assert.Equal(new WorkerEvent(Running: new WorkerRunningEvent(_taskId)), taskEvent.EventData); + Assert.Equal(new WorkerEvent(Running: new WorkerRunningEvent(_jobId, _taskId)), taskEvent.EventData); })); } diff --git a/src/ApiService/IntegrationTests/_FunctionTestBase.cs b/src/ApiService/IntegrationTests/_FunctionTestBase.cs index 65a664c0c3..226409de5b 100644 --- a/src/ApiService/IntegrationTests/_FunctionTestBase.cs +++ b/src/ApiService/IntegrationTests/_FunctionTestBase.cs @@ -2,7 +2,6 @@ using System.IO; using System.Linq; using System.Net.Http; -using System.Threading.Tasks; using ApiService.OneFuzzLib.Orm; using Azure.Data.Tables; using Azure.Storage.Blobs; @@ -78,7 +77,7 @@ public async Task InitializeAsync() { public async Task DisposeAsync() { // clean up any tables & blobs that this test created var account = _storage.GetPrimaryAccount(StorageType.Config); - await ( + await Task.WhenAll( CleanupTables(await _storage.GetTableServiceClientForAccount(account)), CleanupBlobs(await _storage.GetBlobServiceClientForAccount(account))); } diff --git a/src/ApiService/Tests/RequestsTests.cs b/src/ApiService/Tests/RequestsTests.cs index e7f6f577c3..8e9cbf3af3 100644 --- a/src/ApiService/Tests/RequestsTests.cs +++ b/src/ApiService/Tests/RequestsTests.cs @@ -91,6 +91,7 @@ public void NodeEvent_WorkerEvent_Done() { ""event"": { ""worker_event"": { ""done"": { + ""job_id"": ""40a6e135-b6e0-4dc4-837d-0401db0061fb"", ""task_id"": ""00e1b131-e2a1-444d-8cc6-841e6cd48f93"", ""exit_status"": { ""code"": 0, @@ -114,6 +115,7 @@ public void NodeEvent_WorkerEvent_Running() { ""event"": { ""worker_event"": { ""running"": { + ""job_id"": ""a46bf12b-1837-48a6-b6a1-4e4b1c371c25"", ""task_id"": ""1763e113-02a0-4a3e-b477-92762f030d95"" } } @@ -152,17 +154,23 @@ public void NodeEvent_StateUpdate_Free() { [Fact] public void NodeEvent_StateUpdate_SettingUp() { - // generated with: onefuzz-agent debug node_event state_update '"setting_up"' + // generated with: onefuzz-agent debug node_event state_update setting-up AssertRoundtrips(@"{ ""event"": { ""state_update"": { ""state"": ""setting_up"", ""data"": { - ""tasks"": [ - ""163121e2-7df3-4567-9bd8-21b1653fac83"", - ""00604d49-b400-4877-8630-1d6ade31a61d"", - ""719a6316-98c4-4e77-9f3a-324f09505887"" + ""tasks"": null, + ""task_data"": [ + { + ""job_id"": ""b99d0d26-cb46-48af-8770-4768e1262d1c"", + ""task_id"": ""f78f8b2d-3ce1-466e-968b-c61fb9d49d58"" + }, + { + ""job_id"": ""dee926cf-a20a-4e6f-b806-324e64b07243"", + ""task_id"": 
""61178115-34d8-43d2-8ee0-47f065bd7f74"" + } ] } } diff --git a/src/agent/onefuzz-agent/src/agent.rs b/src/agent/onefuzz-agent/src/agent.rs index e1b29f1a40..3f4ddbf33c 100644 --- a/src/agent/onefuzz-agent/src/agent.rs +++ b/src/agent/onefuzz-agent/src/agent.rs @@ -216,8 +216,17 @@ impl Agent { async fn setting_up(mut self, state: State, previous: NodeState) -> Result { info!("agent setting up"); - let tasks = state.work_set().task_ids(); - self.emit_state_update_if_changed(StateUpdateEvent::SettingUp { tasks }) + let tasks = state + .work_set() + .work_units + .iter() + .map(|w| SettingUpData { + job_id: w.job_id, + task_id: w.task_id, + }) + .collect(); + + self.emit_state_update_if_changed(StateUpdateEvent::SettingUp { task_data: tasks }) .await?; let scheduler = match state.finish(self.setup_runner.as_mut()).await? { diff --git a/src/agent/onefuzz-agent/src/agent/tests.rs b/src/agent/onefuzz-agent/src/agent/tests.rs index f0761e59f4..b1590c51f3 100644 --- a/src/agent/onefuzz-agent/src/agent/tests.rs +++ b/src/agent/onefuzz-agent/src/agent/tests.rs @@ -159,14 +159,19 @@ async fn test_emitted_state() { let expected_events: Vec = vec![ NodeEvent::StateUpdate(StateUpdateEvent::Free), NodeEvent::StateUpdate(StateUpdateEvent::SettingUp { - tasks: vec![Fixture.task_id()], + task_data: vec![SettingUpData { + task_id: Fixture.task_id(), + job_id: Fixture.job_id(), + }], }), NodeEvent::StateUpdate(StateUpdateEvent::Ready), NodeEvent::StateUpdate(StateUpdateEvent::Busy), NodeEvent::WorkerEvent(WorkerEvent::Running { + job_id: Fixture.job_id(), task_id: Fixture.task_id(), }), NodeEvent::WorkerEvent(WorkerEvent::Done { + job_id: Fixture.job_id(), task_id: Fixture.task_id(), exit_status: ExitStatus { code: Some(0), @@ -218,7 +223,10 @@ async fn test_emitted_state_failed_setup() { let expected_events: Vec = vec![ NodeEvent::StateUpdate(StateUpdateEvent::Free), NodeEvent::StateUpdate(StateUpdateEvent::SettingUp { - tasks: vec![Fixture.task_id()], + task_data: vec![SettingUpData { + task_id: Fixture.task_id(), + job_id: Fixture.job_id(), + }], }), NodeEvent::StateUpdate(StateUpdateEvent::Done { error: Some(String::from(error_message)), diff --git a/src/agent/onefuzz-agent/src/coordinator.rs b/src/agent/onefuzz-agent/src/coordinator.rs index e63f250b11..3efa44267d 100644 --- a/src/agent/onefuzz-agent/src/coordinator.rs +++ b/src/agent/onefuzz-agent/src/coordinator.rs @@ -86,13 +86,20 @@ impl From for NodeEvent { } } +#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] +#[serde(rename_all = "snake_case")] +pub struct SettingUpData { + pub task_id: Uuid, + pub job_id: Uuid, +} + #[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] #[serde(rename_all = "snake_case", tag = "state", content = "data")] pub enum StateUpdateEvent { Init, Free, SettingUp { - tasks: Vec, + task_data: Vec, }, Rebooting, Ready, @@ -125,6 +132,7 @@ pub enum TaskState { #[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] pub struct CanScheduleRequest { machine_id: Uuid, + job_id: Uuid, task_id: Uuid, } @@ -265,8 +273,10 @@ impl Coordinator { // need to make sure that other the work units in the set have their states // updated if necessary. 
let task_id = work_set.work_units[0].task_id; + let job_id = work_set.work_units[0].job_id; let envelope = CanScheduleRequest { machine_id: self.registration.machine_id, + job_id, task_id, }; diff --git a/src/agent/onefuzz-agent/src/debug.rs b/src/agent/onefuzz-agent/src/debug.rs index 2e18b1e1d8..75970fe49c 100644 --- a/src/agent/onefuzz-agent/src/debug.rs +++ b/src/agent/onefuzz-agent/src/debug.rs @@ -59,8 +59,17 @@ fn debug_node_event_state_update(state: NodeState) -> Result<()> { NodeState::Init => StateUpdateEvent::Init, NodeState::Free => StateUpdateEvent::Free, NodeState::SettingUp => { - let tasks = vec![Uuid::new_v4(), Uuid::new_v4(), Uuid::new_v4()]; - StateUpdateEvent::SettingUp { tasks } + let tasks = vec![ + SettingUpData { + task_id: Uuid::new_v4(), + job_id: Uuid::new_v4(), + }, + SettingUpData { + task_id: Uuid::new_v4(), + job_id: Uuid::new_v4(), + }, + ]; + StateUpdateEvent::SettingUp { task_data: tasks } } NodeState::Rebooting => StateUpdateEvent::Rebooting, NodeState::Ready => StateUpdateEvent::Ready, @@ -88,9 +97,10 @@ pub enum WorkerEventOpt { fn debug_node_event_worker_event(opt: WorkerEventOpt) -> Result<()> { let task_id = uuid::Uuid::new_v4(); + let job_id = uuid::Uuid::new_v4(); let event = match opt { - WorkerEventOpt::Running => WorkerEvent::Running { task_id }, + WorkerEventOpt::Running => WorkerEvent::Running { job_id, task_id }, WorkerEventOpt::Done { code, signal } => { let (code, signal) = match (code, signal) { // Default to ok exit. @@ -111,6 +121,7 @@ fn debug_node_event_worker_event(opt: WorkerEventOpt) -> Result<()> { exit_status, stderr, stdout, + job_id, task_id, } } diff --git a/src/agent/onefuzz-agent/src/main.rs b/src/agent/onefuzz-agent/src/main.rs index 5f34ea6db6..c3e7a284d9 100644 --- a/src/agent/onefuzz-agent/src/main.rs +++ b/src/agent/onefuzz-agent/src/main.rs @@ -267,6 +267,7 @@ async fn check_existing_worksets(coordinator: &mut coordinator::Coordinator) -> for unit in &work.work_units { let event = WorkerEvent::Done { + job_id: unit.job_id, task_id: unit.task_id, stdout: "".to_string(), stderr: failure.clone(), diff --git a/src/agent/onefuzz-agent/src/work.rs b/src/agent/onefuzz-agent/src/work.rs index d0222744a7..3edab28291 100644 --- a/src/agent/onefuzz-agent/src/work.rs +++ b/src/agent/onefuzz-agent/src/work.rs @@ -29,10 +29,6 @@ pub struct WorkSet { } impl WorkSet { - pub fn task_ids(&self) -> Vec { - self.work_units.iter().map(|w| w.task_id).collect() - } - pub fn context_path(machine_id: Uuid) -> Result { Ok(onefuzz::fs::onefuzz_root()?.join(format!("workset_context-{machine_id}.json"))) } diff --git a/src/agent/onefuzz-agent/src/worker.rs b/src/agent/onefuzz-agent/src/worker.rs index d05a95dacb..f88b9a25bf 100644 --- a/src/agent/onefuzz-agent/src/worker.rs +++ b/src/agent/onefuzz-agent/src/worker.rs @@ -35,9 +35,11 @@ const MAX_TAIL_LEN: usize = 40960; #[serde(rename_all = "snake_case")] pub enum WorkerEvent { Running { + job_id: JobId, task_id: TaskId, }, Done { + job_id: JobId, task_id: TaskId, exit_status: ExitStatus, stderr: String, @@ -83,6 +85,7 @@ impl Worker { Worker::Ready(state) => { let state = state.run(runner).await?; let event = WorkerEvent::Running { + job_id: state.work.job_id, task_id: state.work.task_id, }; events.push(event); @@ -95,6 +98,7 @@ impl Worker { exit_status: output.exit_status, stderr: output.stderr, stdout: output.stdout, + job_id: state.work.job_id, task_id: state.work.task_id, }; events.push(event); diff --git a/src/agent/onefuzz-agent/src/worker/tests.rs 
b/src/agent/onefuzz-agent/src/worker/tests.rs index 8b1dd803b2..2bfea89ae2 100644 --- a/src/agent/onefuzz-agent/src/worker/tests.rs +++ b/src/agent/onefuzz-agent/src/worker/tests.rs @@ -193,6 +193,7 @@ async fn test_running_wait_done() { #[tokio::test] async fn test_worker_ready_update() { let task_id = Fixture.work().task_id; + let job_id = Fixture.work().job_id; let state = State { ctx: Ready { @@ -208,7 +209,7 @@ async fn test_worker_ready_update() { let worker = worker.update(&mut events, &mut runner).await.unwrap(); assert!(matches!(worker, Worker::Running(..))); - assert_eq!(events, vec![WorkerEvent::Running { task_id }]); + assert_eq!(events, vec![WorkerEvent::Running { job_id, task_id }]); } #[tokio::test] @@ -258,6 +259,7 @@ async fn test_worker_running_update_done() { assert_eq!( events, vec![WorkerEvent::Done { + job_id: Fixture.work().job_id, task_id: Fixture.work().task_id, exit_status, stderr: "stderr".into(),