Skip to content
This repository has been archived by the owner on Nov 1, 2023. It is now read-only.

Commit

Permalink
Improve slow Task queries (#3594)
Browse files Browse the repository at this point in the history
Querying Tasks only by TaskID is slow as it requires querying many partitions. Querying Task by JobID and TaskID is fast as it is a point lookup.

1. rename existing `GetByTaskId` to `GetByTaskIdSlow`
2. include JobID in event data from agent (new properties are optional for back-compat)
3. include JobID in NodeTask entity (if this is missing upon lookup we will fall back to `GetByTaskIdSlow`)
  • Loading branch information
Porges authored Oct 24, 2023
1 parent ca3c503 commit 957f1fb
Show file tree
Hide file tree
Showing 22 changed files with 156 additions and 68 deletions.
7 changes: 6 additions & 1 deletion src/ApiService/ApiService/Functions/AgentCanSchedule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@ public async Async.Task<HttpResponseData> Run(
var allowed = canProcessNewWork.IsAllowed;
var reason = canProcessNewWork.Reason;

var task = await _context.TaskOperations.GetByTaskId(canScheduleRequest.TaskId);
var task = await
(canScheduleRequest.JobId.HasValue
? _context.TaskOperations.GetByJobIdAndTaskId(canScheduleRequest.JobId.Value, canScheduleRequest.TaskId)
// old agent, fall back
: _context.TaskOperations.GetByTaskIdSlow(canScheduleRequest.TaskId));

var workStopped = task == null || task.State.ShuttingDown();
if (!allowed) {
_log.LogInformation("Node cannot process new work {PoolName} {ScalesetId} - {MachineId} ", node.PoolName, node.ScalesetId, node.MachineId);
Expand Down
28 changes: 21 additions & 7 deletions src/ApiService/ApiService/Functions/AgentEvents.cs
Original file line number Diff line number Diff line change
Expand Up @@ -110,18 +110,23 @@ public async Async.Task<HttpResponseData> Run(
_log.LogInformation("node now available for work: {MachineId}", machineId);
} else if (ev.State == NodeState.SettingUp) {
if (ev.Data is NodeSettingUpEventData settingUpData) {
if (!settingUpData.Tasks.Any()) {
if (settingUpData.Tasks?.Any() == false || settingUpData.TaskData?.Any() == false) {
return Error.Create(ErrorCode.INVALID_REQUEST,
$"setup without tasks. machine_id: {machineId}"
);
}

foreach (var taskId in settingUpData.Tasks) {
var task = await _context.TaskOperations.GetByTaskId(taskId);
var tasks =
settingUpData.Tasks is not null
? settingUpData.Tasks.Select(t => (t, _context.TaskOperations.GetByTaskIdSlow(t)))
: settingUpData.TaskData!.Select(t => (t.TaskId, _context.TaskOperations.GetByJobIdAndTaskId(t.JobId, t.TaskId)));

foreach (var (id, taskTask) in tasks) {
var task = await taskTask;
if (task is null) {
return Error.Create(
ErrorCode.INVALID_REQUEST,
$"unable to find task: {taskId}");
$"unable to find task: {id}");
}

_log.LogInformation("node starting task. {MachineId} {JobId} {TaskId}", machineId, task.JobId, task.TaskId);
Expand All @@ -139,6 +144,7 @@ public async Async.Task<HttpResponseData> Run(
var nodeTask = new NodeTasks(
MachineId: machineId,
TaskId: task.TaskId,
JobId: task.JobId,
State: NodeTaskState.SettingUp);
var r = await _context.NodeTasksOperations.Replace(nodeTask);
if (!r.IsOk) {
Expand Down Expand Up @@ -183,7 +189,10 @@ public async Async.Task<HttpResponseData> Run(

private async Async.Task<Error?> OnWorkerEventRunning(Guid machineId, WorkerRunningEvent running) {
var (task, node) = await (
_context.TaskOperations.GetByTaskId(running.TaskId),
(running.JobId.HasValue
? _context.TaskOperations.GetByJobIdAndTaskId(running.JobId.Value, running.TaskId)
// old agent, fallback
: _context.TaskOperations.GetByTaskIdSlow(running.TaskId)),
_context.NodeOperations.GetByMachineId(machineId));

if (task is null) {
Expand All @@ -202,6 +211,7 @@ public async Async.Task<HttpResponseData> Run(
var nodeTask = new NodeTasks(
MachineId: machineId,
TaskId: running.TaskId,
JobId: running.JobId,
State: NodeTaskState.Running);
var r = await _context.NodeTasksOperations.Replace(nodeTask);
if (!r.IsOk) {
Expand Down Expand Up @@ -231,8 +241,12 @@ public async Async.Task<HttpResponseData> Run(
}

private async Async.Task<Error?> OnWorkerEventDone(Guid machineId, WorkerDoneEvent done) {

var (task, node) = await (
_context.TaskOperations.GetByTaskId(done.TaskId),
(done.JobId.HasValue
? _context.TaskOperations.GetByJobIdAndTaskId(done.JobId.Value, done.TaskId)
// old agent, fall back
: _context.TaskOperations.GetByTaskIdSlow(done.TaskId)),
_context.NodeOperations.GetByMachineId(machineId));

if (task is null) {
Expand Down Expand Up @@ -285,7 +299,7 @@ await _context.TaskOperations.MarkFailed(
}

if (!node.DebugKeepNode) {
var r = await _context.NodeTasksOperations.Delete(new NodeTasks(machineId, done.TaskId));
var r = await _context.NodeTasksOperations.Delete(new NodeTasks(machineId, done.TaskId, done.JobId));
if (!r.IsOk) {
_log.AddHttpStatus(r.ErrorV);
_log.LogError("failed to deleting node task {TaskId} for: {MachineId} since DebugKeepNode is false", done.TaskId, machineId);
Expand Down
2 changes: 1 addition & 1 deletion src/ApiService/ApiService/Functions/QueueJobResult.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public async Async.Task Run([QueueTrigger("job-result", Connection = "AzureWebJo
_log.LogInformation("job result: {msg}", msg);
var jr = JsonSerializer.Deserialize<TaskJobResultEntry>(msg, EntityConverter.GetJsonSerializerOptions()).EnsureNotNull($"wrong data {msg}");

var task = await _tasks.GetByTaskId(jr.TaskId);
var task = await _tasks.GetByTaskIdSlow(jr.TaskId);
if (task == null) {
_log.LogWarning("invalid {TaskId}", jr.TaskId);
return;
Expand Down
13 changes: 7 additions & 6 deletions src/ApiService/ApiService/Functions/QueueTaskHeartbeat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,18 @@ public async Async.Task Run([QueueTrigger("task-heartbeat", Connection = "AzureW
_log.LogInformation("heartbeat: {msg}", msg);
var hb = JsonSerializer.Deserialize<TaskHeartbeatEntry>(msg, EntityConverter.GetJsonSerializerOptions()).EnsureNotNull($"wrong data {msg}");

var task = await _tasks.GetByTaskId(hb.TaskId);
if (task == null) {
_log.LogWarning("invalid {TaskId}", hb.TaskId);
var job = await _jobs.Get(hb.JobId);
if (job == null) {
_log.LogWarning("invalid {JobId}", hb.JobId);
return;
}

var job = await _jobs.Get(task.JobId);
if (job == null) {
_log.LogWarning("invalid {JobId}", task.JobId);
var task = await _tasks.GetByJobIdAndTaskId(hb.JobId, hb.TaskId);
if (task == null) {
_log.LogWarning("invalid {TaskId}", hb.TaskId);
return;
}

var newTask = task with { Heartbeat = DateTimeOffset.UtcNow };
var r = await _tasks.Replace(newTask);
if (!r.IsOk) {
Expand Down
7 changes: 3 additions & 4 deletions src/ApiService/ApiService/Functions/Tasks.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ private async Async.Task<HttpResponseData> Get(HttpRequestData req) {
}

if (request.OkV.TaskId is Guid taskId) {
var task = await _context.TaskOperations.GetByTaskId(taskId);
var task = await _context.TaskOperations.GetByTaskIdSlow(taskId);
if (task == null) {
return await _context.RequestHandling.NotOk(
req,
Expand Down Expand Up @@ -128,8 +128,7 @@ private async Async.Task<HttpResponseData> Post(HttpRequestData req, FunctionCon

if (cfg.PrereqTasks != null) {
foreach (var taskId in cfg.PrereqTasks) {
var prereq = await _context.TaskOperations.GetByTaskId(taskId);

var prereq = await _context.TaskOperations.GetByJobIdAndTaskId(cfg.JobId, taskId);
if (prereq == null) {
return await _context.RequestHandling.NotOk(
req,
Expand Down Expand Up @@ -163,7 +162,7 @@ private async Async.Task<HttpResponseData> Delete(HttpRequestData req) {
}


var task = await _context.TaskOperations.GetByTaskId(request.OkV.TaskId);
var task = await _context.TaskOperations.GetByTaskIdSlow(request.OkV.TaskId);
if (task == null) {
return await _context.RequestHandling.NotOk(req, Error.Create(ErrorCode.INVALID_REQUEST, "unable to find task"
), "task delete");
Expand Down
2 changes: 1 addition & 1 deletion src/ApiService/ApiService/Functions/TimerRetention.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ from container in task.Config.Containers
}
} else if (Guid.TryParse(q.Name, out queueId)) {
//this is a task queue
var taskQueue = await _taskOps.GetByTaskId(queueId);
var taskQueue = await _taskOps.GetByTaskIdSlow(queueId);
if (taskQueue is null) {
// task does not exist. Ok to delete the task queue
_log.LogInformation("Deleting {TaskQueueName} since task could not be found in Task table", q.Name);
Expand Down
10 changes: 6 additions & 4 deletions src/ApiService/ApiService/OneFuzzTypes/Model.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Reflection;
using System.ComponentModel.DataAnnotations;
using System.Reflection;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text.RegularExpressions;
Expand Down Expand Up @@ -49,9 +50,9 @@ public enum JobResultType {
public record HeartbeatData(HeartbeatType Type);

public record TaskHeartbeatEntry(
Guid TaskId,
Guid? JobId,
Guid MachineId,
[property: Required] Guid TaskId,
[property: Required] Guid JobId,
[property: Required] Guid MachineId,
HeartbeatData[] Data);

public record JobResultData(JobResultType Type);
Expand Down Expand Up @@ -99,6 +100,7 @@ public record NodeTasks
(
[PartitionKey] Guid MachineId,
[RowKey] Guid TaskId,
Guid? JobId, // not necessarily populated in old records
NodeTaskState State = NodeTaskState.Init
) : StatefulEntityBase<NodeTaskState>(State);

Expand Down
12 changes: 11 additions & 1 deletion src/ApiService/ApiService/OneFuzzTypes/Requests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ public record BaseRequest {

// Request sent by an agent to ask whether it may schedule the given task.
public record CanScheduleRequest(
[property: Required] Guid MachineId,
// Optional for back-compat: older agents do not send a job ID. When it is
// absent the handler falls back to the slow task-ID-only lookup.
Guid? JobId,
[property: Required] Guid TaskId
) : BaseRequest;

Expand Down Expand Up @@ -63,9 +64,11 @@ public record WorkerEvent(
) : NodeEventBase;

// Worker event reporting that a task is now running on the node.
public record WorkerRunningEvent(
// Optional for back-compat: older agents do not send a job ID. When it is
// absent the task is looked up by task ID alone (slow path).
Guid? JobId,
[property: Required] Guid TaskId);

public record WorkerDoneEvent(
Guid? JobId,
[property: Required] Guid TaskId,
[property: Required] ExitStatus ExitStatus,
[property: Required] string Stderr,
Expand All @@ -81,8 +84,15 @@ public record NodeStateUpdate(
[JsonConverter(typeof(SubclassConverter<NodeStateData>))]
public abstract record NodeStateData;

// Identifies one task a node is setting up. Carrying both the job ID and the
// task ID lets the service look the task up by JobId + TaskId instead of
// falling back to the slow task-ID-only query.
public record NodeSettingUpData(
[property: Required] Guid JobId,
[property: Required] Guid TaskId);

// TODO [future]: remove Tasks and make TaskData Required
// once all agents are compatible
public record NodeSettingUpEventData(
[property: Required] List<Guid> Tasks
List<Guid>? Tasks,
List<NodeSettingUpData>? TaskData
) : NodeStateData;

public record NodeDoneEventData(
Expand Down
21 changes: 17 additions & 4 deletions src/ApiService/ApiService/onefuzzlib/NodeOperations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,12 @@ public IAsyncEnumerable<Node> SearchByPoolName(PoolName poolName) {

public async Async.Task MarkTasksStoppedEarly(Node node, Error? error) {
await foreach (var entry in _context.NodeTasksOperations.GetByMachineId(node.MachineId)) {
var task = await _context.TaskOperations.GetByTaskId(entry.TaskId);
var task = await
(entry.JobId.HasValue
? _context.TaskOperations.GetByJobIdAndTaskId(entry.JobId.Value, entry.TaskId)
// old data might not have job ID:
: _context.TaskOperations.GetByTaskIdSlow(entry.TaskId));

if (task is not null && !TaskStateHelper.ShuttingDown(task.State)) {
var message = $"Node {node.MachineId} stopping while the task state is '{task.State}'";
if (error is not null) {
Expand Down Expand Up @@ -701,10 +706,18 @@ public async Task<OneFuzzResult<bool>> AddSshPublicKey(Node node, string publicK

/// returns True on stopping the node and False if this doesn't stop the node
private async Task<bool> StopIfComplete(Node node, bool done = false) {
var nodeTaskIds = await _context.NodeTasksOperations.GetByMachineId(node.MachineId).Select(nt => nt.TaskId).ToArrayAsync();
var tasks = _context.TaskOperations.GetByTaskIds(nodeTaskIds);
var tasks = _context.NodeTasksOperations.GetByMachineId(node.MachineId)
.SelectAwait(async node => {
if (node.JobId.HasValue) {
return await _context.TaskOperations.GetByJobIdAndTaskId(node.JobId.Value, node.TaskId);
} else {
// old existing records might not have jobId - fall back to slow lookup
return await _context.TaskOperations.GetByTaskIdSlow(node.TaskId);
}
});

await foreach (var task in tasks) {
if (!TaskStateHelper.ShuttingDown(task.State)) {
if (task is not null && !TaskStateHelper.ShuttingDown(task.State)) {
return false;
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/ApiService/ApiService/onefuzzlib/ReproOperations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public IAsyncEnumerable<Repro> SearchExpired() {
public async Async.Task<Vm> GetVm(Repro repro, InstanceConfig config) {
var taskOperations = _context.TaskOperations;
var tags = config.VmTags;
var task = await taskOperations.GetByTaskId(repro.TaskId);
var task = await taskOperations.GetByTaskIdSlow(repro.TaskId);
if (task == null) {
throw new Exception($"previous existing task missing: {repro.TaskId}");
}
Expand Down Expand Up @@ -242,7 +242,7 @@ public async Task<OneFuzzResultVoid> BuildReproScript(Repro repro) {
);
}

var task = await _context.TaskOperations.GetByTaskId(repro.TaskId);
var task = await _context.TaskOperations.GetByTaskIdSlow(repro.TaskId);
if (task == null) {
return OneFuzzResultVoid.Error(
ErrorCode.VM_CREATE_FAILED,
Expand Down Expand Up @@ -324,7 +324,7 @@ public async Async.Task<Repro> SetError(Repro repro, Error result) {
}

public async Task<Container?> GetSetupContainer(Repro repro) {
var task = await _context.TaskOperations.GetByTaskId(repro.TaskId);
var task = await _context.TaskOperations.GetByTaskIdSlow(repro.TaskId);
return task?.Config?.Containers?
.Where(container => container.Type == ContainerType.Setup)
.FirstOrDefault()?
Expand All @@ -337,7 +337,7 @@ public async Task<OneFuzzResult<Repro>> Create(ReproConfig config, UserInfo user
return OneFuzzResult<Repro>.Error(ErrorCode.UNABLE_TO_FIND, "unable to find report");
}

var task = await _context.TaskOperations.GetByTaskId(report.TaskId);
var task = await _context.TaskOperations.GetByTaskIdSlow(report.TaskId);
if (task is null) {
return OneFuzzResult<Repro>.Error(ErrorCode.INVALID_REQUEST, "unable to find task");
}
Expand Down
15 changes: 4 additions & 11 deletions src/ApiService/ApiService/onefuzzlib/TaskOperations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
namespace Microsoft.OneFuzz.Service;

public interface ITaskOperations : IStatefulOrm<Task, TaskState> {
Async.Task<Task?> GetByTaskId(Guid taskId);

IAsyncEnumerable<Task> GetByTaskIds(IEnumerable<Guid> taskId);
Task<Task?> GetByTaskIdSlow(Guid taskId);

IAsyncEnumerable<Task> GetByJobId(Guid jobId);

Expand Down Expand Up @@ -47,13 +45,8 @@ public TaskOperations(ILogger<TaskOperations> log, IMemoryCache cache, IOnefuzzC
_cache = cache;
}

public async Async.Task<Task?> GetByTaskId(Guid taskId) {
return await GetByTaskIds(new[] { taskId }).FirstOrDefaultAsync();
}

public IAsyncEnumerable<Task> GetByTaskIds(IEnumerable<Guid> taskId) {
return QueryAsync(filter: Query.RowKeys(taskId.Select(t => t.ToString())));
}
/// Looks up a task using only its task ID; yields null when nothing matches.
/// Slow: the filter is on the row key alone, so the query cannot be limited
/// to a single partition. Prefer GetByJobIdAndTaskId when the job ID is known.
public async Async.Task<Task?> GetByTaskIdSlow(Guid taskId) {
    var rowKeyFilter = Query.RowKey(taskId.ToString());
    return await QueryAsync(filter: rowKeyFilter).FirstOrDefaultAsync();
}

public IAsyncEnumerable<Task> GetByJobId(Guid jobId) {
return QueryAsync(Query.PartitionKey(jobId.ToString()));
Expand Down Expand Up @@ -276,7 +269,7 @@ private async Async.Task<Task> OnStart(Task task) {
public async Async.Task<bool> CheckPrereqTasks(Task task) {
if (task.Config.PrereqTasks != null) {
foreach (var taskId in task.Config.PrereqTasks) {
var t = await GetByTaskId(taskId);
var t = await GetByJobIdAndTaskId(task.JobId, taskId);

// if a prereq task fails, then mark this task as failed
if (t == null) {
Expand Down
11 changes: 7 additions & 4 deletions src/ApiService/IntegrationTests/AgentEventsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ await Context.InsertAll(
var data = new NodeStateEnvelope(
MachineId: _machineId,
Event: new WorkerEvent(Done: new WorkerDoneEvent(
JobId: _jobId,
TaskId: _taskId,
ExitStatus: new ExitStatus(Code: 0, Signal: 0, Success: true),
"stderr",
Expand Down Expand Up @@ -88,6 +89,7 @@ await Context.InsertAll(
var data = new NodeStateEnvelope(
MachineId: _machineId,
Event: new WorkerEvent(Done: new WorkerDoneEvent(
JobId: _jobId,
TaskId: _taskId,
ExitStatus: new ExitStatus(Code: 0, Signal: 0, Success: false), // unsuccessful result
"stderr",
Expand All @@ -114,6 +116,7 @@ await Context.InsertAll(
var data = new NodeStateEnvelope(
MachineId: _machineId,
Event: new WorkerEvent(Done: new WorkerDoneEvent(
JobId: _jobId,
TaskId: _taskId,
ExitStatus: new ExitStatus(0, 0, true),
"stderr",
Expand All @@ -137,7 +140,7 @@ await Context.InsertAll(
var func = new AgentEvents(LoggerProvider.CreateLogger<AgentEvents>(), Context);
var data = new NodeStateEnvelope(
MachineId: _machineId,
Event: new WorkerEvent(Running: new WorkerRunningEvent(_taskId)));
Event: new WorkerEvent(Running: new WorkerRunningEvent(_jobId, _taskId)));

var result = await func.Run(TestHttpRequestData.FromJson("POST", data));
Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode);
Expand All @@ -153,7 +156,7 @@ await Context.InsertAll(
var func = new AgentEvents(LoggerProvider.CreateLogger<AgentEvents>(), Context);
var data = new NodeStateEnvelope(
MachineId: _machineId,
Event: new WorkerEvent(Running: new WorkerRunningEvent(_taskId)));
Event: new WorkerEvent(Running: new WorkerRunningEvent(_jobId, _taskId)));

var result = await func.Run(TestHttpRequestData.FromJson("POST", data));
Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode);
Expand All @@ -170,7 +173,7 @@ await Context.InsertAll(
var func = new AgentEvents(LoggerProvider.CreateLogger<AgentEvents>(), Context);
var data = new NodeStateEnvelope(
MachineId: _machineId,
Event: new WorkerEvent(Running: new WorkerRunningEvent(_taskId)));
Event: new WorkerEvent(Running: new WorkerRunningEvent(_jobId, _taskId)));

var result = await func.Run(TestHttpRequestData.FromJson("POST", data));
Assert.Equal(HttpStatusCode.OK, result.StatusCode);
Expand Down Expand Up @@ -199,7 +202,7 @@ await Async.Task.WhenAll(
var taskEvent = await Context.TaskEventOperations.SearchAll().SingleAsync();
Assert.Equal(_taskId, taskEvent.TaskId);
Assert.Equal(_machineId, taskEvent.MachineId);
Assert.Equal(new WorkerEvent(Running: new WorkerRunningEvent(_taskId)), taskEvent.EventData);
Assert.Equal(new WorkerEvent(Running: new WorkerRunningEvent(_jobId, _taskId)), taskEvent.EventData);
}));
}

Expand Down
Loading

0 comments on commit 957f1fb

Please sign in to comment.