diff --git a/config/armada/config.yaml b/config/armada/config.yaml index 265c2d80822..90b8af17421 100644 --- a/config/armada/config.yaml +++ b/config/armada/config.yaml @@ -34,6 +34,7 @@ scheduling: preemption: nodeEvictionProbability: 1.0 nodeOversubscriptionEvictionProbability: 1.0 + protectedFractionOfFairShare: 1.0 setNodeIdSelector: true nodeIdLabel: kubernetes.io/hostname setNodeName: false diff --git a/config/executor/config.yaml b/config/executor/config.yaml index c81beb2773d..44cb869cc91 100644 --- a/config/executor/config.yaml +++ b/config/executor/config.yaml @@ -59,6 +59,9 @@ kubernetes: fatalPodSubmissionErrors: - "admission webhook" - "namespaces \".*\" not found" + stateChecks: + deadlineForSubmittedPodConsideredMissing: 15m + deadlineForActivePodConsideredMissing: 5m pendingPodChecks: deadlineForUpdates: 10m deadlineForNodeAssignment: 5m diff --git a/internal/armada/configuration/types.go b/internal/armada/configuration/types.go index dc93ae7ec4d..cf9db749e82 100644 --- a/internal/armada/configuration/types.go +++ b/internal/armada/configuration/types.go @@ -209,6 +209,8 @@ type PreemptionConfig struct { // the probability of evicting jobs on oversubscribed nodes, i.e., // nodes on which the total resource requests are greater than the available resources. NodeOversubscriptionEvictionProbability float64 + // Only queues allocated more than this fraction of their fair share are considered for preemption. + ProtectedFractionOfFairShare float64 // If true, the Armada scheduler will add to scheduled pods a node selector // NodeIdLabel: . // If true, NodeIdLabel must be non-empty. diff --git a/internal/armada/server/lease.go b/internal/armada/server/lease.go index 486a8bec099..456ba84b815 100644 --- a/internal/armada/server/lease.go +++ b/internal/armada/server/lease.go @@ -353,7 +353,7 @@ func (q *AggregatedQueueServer) getJobs(ctx context.Context, req *api.StreamingL // Group gangs. 
for _, job := range jobs { - gangId, _, isGangJob, err := scheduler.GangIdAndCardinalityFromLegacySchedulerJob(job, q.schedulingConfig.Preemption.PriorityClasses) + gangId, _, isGangJob, err := scheduler.GangIdAndCardinalityFromLegacySchedulerJob(job) if err != nil { return nil, err } @@ -469,7 +469,11 @@ func (q *AggregatedQueueServer) getJobs(ctx context.Context, req *api.StreamingL schedulerobjects.ResourceList{Resources: totalCapacity}, ) for queue, priorityFactor := range priorityFactorByQueue { - if err := sctx.AddQueueSchedulingContext(queue, priorityFactor, allocatedByQueueAndPriorityClassForPool[queue]); err != nil { + var weight float64 = 1 + if priorityFactor > 0 { + weight = 1 / priorityFactor + } + if err := sctx.AddQueueSchedulingContext(queue, weight, allocatedByQueueAndPriorityClassForPool[queue]); err != nil { return nil, err } } @@ -484,6 +488,7 @@ func (q *AggregatedQueueServer) getJobs(ctx context.Context, req *api.StreamingL constraints, q.schedulingConfig.Preemption.NodeEvictionProbability, q.schedulingConfig.Preemption.NodeOversubscriptionEvictionProbability, + q.schedulingConfig.Preemption.ProtectedFractionOfFairShare, &SchedulerJobRepositoryAdapter{ r: q.jobRepository, }, diff --git a/internal/common/database/lookout/jobstates.go b/internal/common/database/lookout/jobstates.go index 9ba1ce54f31..20ea463dde5 100644 --- a/internal/common/database/lookout/jobstates.go +++ b/internal/common/database/lookout/jobstates.go @@ -51,6 +51,18 @@ const ( ) var ( + // JobStates is an ordered list of states + JobStates = []JobState{ + JobQueued, + JobLeased, + JobPending, + JobRunning, + JobSucceeded, + JobFailed, + JobCancelled, + JobPreempted, + } + JobStateMap = map[int]JobState{ JobLeasedOrdinal: JobLeased, JobQueuedOrdinal: JobQueued, diff --git a/internal/executor/application.go b/internal/executor/application.go index 47c9f02dd33..d7ad549c47c 100644 --- a/internal/executor/application.go +++ b/internal/executor/application.go @@ -189,9 +189,11 @@ func setupExecutorApiComponents( jobRunState, submitter, etcdHealthMonitor) - podIssueService := service.NewPodIssueService( + podIssueService := service.NewIssueHandler( + jobRunState, clusterContext, eventReporter, + config.Kubernetes.StateChecks, pendingPodChecker, config.Kubernetes.StuckTerminatingPodExpiry) diff --git a/internal/executor/configuration/types.go b/internal/executor/configuration/types.go index 04f7ccfa482..4798f29710a 100644 --- a/internal/executor/configuration/types.go +++ b/internal/executor/configuration/types.go @@ -26,6 +26,17 @@ type PodDefaults struct { Ingress *IngressConfiguration } +type StateChecksConfiguration struct { + // Once a pod is submitted to kubernetes, this is how long we'll wait for it to appear in the kubernetes informer state + // If the pod hasn't appeared after this duration, it is considered missing + DeadlineForSubmittedPodConsideredMissing time.Duration + // Once the executor has seen a pod appear on the cluster, it considers that run Active + // If we get into a state where there is no longer a pod backing that Active run, this is how long we'll wait before we consider the pod missing + // The most likely cause of this is actually a bug in the executors processing of the kubernetes state + // However without it - we can have runs get indefinitely stuck as Active with no backing pod + DeadlineForActivePodConsideredMissing time.Duration +} + type IngressConfiguration struct { HostnameSuffix string CertNameSuffix string @@ -54,6 +65,7 @@ type KubernetesConfiguration struct { 
MaxTerminatedPods int MinimumJobSize armadaresource.ComputeResources PodDefaults *PodDefaults + StateChecks StateChecksConfiguration PendingPodChecks *podchecks.Checks FatalPodSubmissionErrors []string // Minimum amount of resources marked as allocated to non-Armada pods on each node. diff --git a/internal/executor/job/job_run_state_store.go b/internal/executor/job/job_run_state_store.go index 421d650e7d8..2752ac5bfb3 100644 --- a/internal/executor/job/job_run_state_store.go +++ b/internal/executor/job/job_run_state_store.go @@ -51,12 +51,14 @@ func NewJobRunStateStore(clusterContext context.ClusterContext) *JobRunStateStor return } - stateStore.reportRunActive(pod) + if !util.IsPodFinishedAndReported(pod) { + stateStore.reportRunActive(pod) + } }, }) // On start up, make sure our state matches current k8s state - err := stateStore.reconcileStateWithKubernetes() + err := stateStore.initialiseStateFromKubernetes() if err != nil { panic(err) } @@ -75,7 +77,7 @@ func NewJobRunStateStoreWithInitialState(initialJobRuns []*RunState) *JobRunStat return stateStore } -func (stateStore *JobRunStateStore) reconcileStateWithKubernetes() error { +func (stateStore *JobRunStateStore) initialiseStateFromKubernetes() error { pods, err := stateStore.clusterContext.GetAllPods() if err != nil { return err @@ -84,7 +86,9 @@ func (stateStore *JobRunStateStore) reconcileStateWithKubernetes() error { return !util.IsLegacyManagedPod(pod) }) for _, pod := range pods { - stateStore.reportRunActive(pod) + if !util.IsPodFinishedAndReported(pod) { + stateStore.reportRunActive(pod) + } } return nil diff --git a/internal/executor/job/job_run_state_store_test.go b/internal/executor/job/job_run_state_store_test.go index 9092ffa90d9..da29c9a4f7f 100644 --- a/internal/executor/job/job_run_state_store_test.go +++ b/internal/executor/job/job_run_state_store_test.go @@ -4,6 +4,7 @@ import ( "fmt" "sort" "testing" + "time" "github.com/stretchr/testify/assert" v1 "k8s.io/api/core/v1" @@ -23,7 +24,7 @@ var defaultRunInfoMeta = &RunMeta{ JobSet: "job-set-1", } -func TestOnStartUp_ReconcilesWithKubernetes(t *testing.T) { +func TestOnStartUp_ReconcilesWithKubernetes_ActivePod(t *testing.T) { existingPod := createPod() jobRunStateManager, _ := setup(t, []*v1.Pod{existingPod}) @@ -38,6 +39,18 @@ func TestOnStartUp_ReconcilesWithKubernetes(t *testing.T) { assert.Equal(t, allKnownJobRuns[0].Phase, Active) } +func TestOnStartUp_ReconcilesWithKubernetes_IgnoresDonePods(t *testing.T) { + donePod := createPod() + donePod.Status.Phase = v1.PodSucceeded + donePod.Annotations[domain.JobDoneAnnotation] = "true" + donePod.Annotations[string(donePod.Status.Phase)] = fmt.Sprintf("%s", time.Now()) + + jobRunStateManager, _ := setup(t, []*v1.Pod{donePod}) + allKnownJobRuns := jobRunStateManager.GetAll() + + assert.Len(t, allKnownJobRuns, 0) +} + func TestReportRunLeased(t *testing.T) { job := &SubmitJob{ Meta: SubmitJobMeta{ diff --git a/internal/executor/reporter/job_event_reporter.go b/internal/executor/reporter/job_event_reporter.go index 1ae228ed4c3..88c1091e002 100644 --- a/internal/executor/reporter/job_event_reporter.go +++ b/internal/executor/reporter/job_event_reporter.go @@ -169,6 +169,7 @@ func (eventReporter *JobEventReporter) reportStatusUpdate(old *v1.Pod, new *v1.P // Don't report status change for pods Armada is deleting // This prevents reporting JobFailed when we delete a pod - for example due to cancellation if util.IsMarkedForDeletion(new) { + log.Infof("not sending event to report pod %s moving into phase %s as pod is marked for 
deletion", new.Name, new.Status.Phase) return } eventReporter.reportCurrentStatus(new) diff --git a/internal/executor/service/pod_issue_handler.go b/internal/executor/service/pod_issue_handler.go index 8d6b6d99200..8d15d26bc84 100644 --- a/internal/executor/service/pod_issue_handler.go +++ b/internal/executor/service/pod_issue_handler.go @@ -8,67 +8,89 @@ import ( log "github.com/sirupsen/logrus" v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/clock" "k8s.io/client-go/tools/cache" + "github.com/armadaproject/armada/internal/executor/configuration" executorContext "github.com/armadaproject/armada/internal/executor/context" + "github.com/armadaproject/armada/internal/executor/job" "github.com/armadaproject/armada/internal/executor/podchecks" "github.com/armadaproject/armada/internal/executor/reporter" "github.com/armadaproject/armada/internal/executor/util" "github.com/armadaproject/armada/pkg/api" ) -type IssueType int +type podIssueType int const ( - UnableToSchedule IssueType = iota + UnableToSchedule podIssueType = iota StuckStartingUp StuckTerminating ExternallyDeleted + ErrorDuringIssueHandling ) type podIssue struct { // A copy of the pod when an issue was detected OriginalPodState *v1.Pod - JobId string - RunId string Message string Retryable bool - Reported bool DeletionRequested bool - Type IssueType + Type podIssueType Cause api.Cause } +type reconciliationIssue struct { + InitialDetectionTime time.Time + OriginalRunState *job.RunState +} + type issue struct { CurrentPodState *v1.Pod - Issue *podIssue + RunIssue *runIssue +} + +type runIssue struct { + JobId string + RunId string + PodIssue *podIssue + ReconciliationIssue *reconciliationIssue + Reported bool } -type PodIssueService struct { +type IssueHandler struct { clusterContext executorContext.ClusterContext eventReporter reporter.EventReporter pendingPodChecker podchecks.PodChecker + stateChecksConfig configuration.StateChecksConfiguration stuckTerminatingPodExpiry time.Duration // JobRunId -> PodIssue - knownPodIssues map[string]*podIssue + knownPodIssues map[string]*runIssue podIssueMutex sync.Mutex + jobRunState job.RunStateStore + clock clock.Clock } -func NewPodIssueService( +func NewIssueHandler( + jobRunState job.RunStateStore, clusterContext executorContext.ClusterContext, eventReporter reporter.EventReporter, + stateChecksConfig configuration.StateChecksConfiguration, pendingPodChecker podchecks.PodChecker, stuckTerminatingPodExpiry time.Duration, -) *PodIssueService { - podIssueService := &PodIssueService{ +) *IssueHandler { + issueHandler := &IssueHandler{ + jobRunState: jobRunState, clusterContext: clusterContext, eventReporter: eventReporter, pendingPodChecker: pendingPodChecker, + stateChecksConfig: stateChecksConfig, stuckTerminatingPodExpiry: stuckTerminatingPodExpiry, - knownPodIssues: map[string]*podIssue{}, + knownPodIssues: map[string]*runIssue{}, podIssueMutex: sync.Mutex{}, + clock: clock.RealClock{}, } clusterContext.AddPodEventHandler(cache.ResourceEventHandlerFuncs{ @@ -78,20 +100,20 @@ func NewPodIssueService( log.Errorf("Failed to process pod event due to it being an unexpected type. 
Failed to process %+v", obj) return } - podIssueService.handleDeletedPod(pod) + issueHandler.handleDeletedPod(pod) }, }) - return podIssueService + return issueHandler } -func (p *PodIssueService) registerIssue(issue *podIssue) { +func (p *IssueHandler) registerIssue(issue *runIssue) { p.podIssueMutex.Lock() defer p.podIssueMutex.Unlock() runId := issue.RunId if runId == "" { - log.Warnf("Not registering an issue for job %s (%s) as run id was empty", issue.JobId, issue.OriginalPodState.Name) + log.Warnf("Not registering an issue for job %s as run id was empty", issue.JobId) return } _, exists := p.knownPodIssues[issue.RunId] @@ -102,18 +124,18 @@ func (p *PodIssueService) registerIssue(issue *podIssue) { } } -func (p *PodIssueService) markIssuesResolved(issue *podIssue) { +func (p *IssueHandler) markIssuesResolved(issue *runIssue) { p.podIssueMutex.Lock() defer p.podIssueMutex.Unlock() delete(p.knownPodIssues, issue.RunId) } -func (p *PodIssueService) markIssueReported(issue *podIssue) { +func (p *IssueHandler) markIssueReported(issue *runIssue) { issue.Reported = true } -func (p *PodIssueService) HandlePodIssues() { +func (p *IssueHandler) HandlePodIssues() { managedPods, err := p.clusterContext.GetBatchPods() if err != nil { log.WithError(err).Errorf("unable to handle pod issus as failed to load pods") @@ -122,26 +144,29 @@ func (p *PodIssueService) HandlePodIssues() { return !util.IsLegacyManagedPod(pod) }) p.detectPodIssues(managedPods) + p.detectReconciliationIssues(managedPods) ctx, cancel := context.WithTimeout(context.Background(), time.Minute*2) defer cancel() - p.handleKnownPodIssues(ctx, managedPods) + p.handleKnownIssues(ctx, managedPods) } -func (p *PodIssueService) detectPodIssues(allManagedPods []*v1.Pod) { +func (p *IssueHandler) detectPodIssues(allManagedPods []*v1.Pod) { for _, pod := range allManagedPods { - if pod.DeletionTimestamp != nil && pod.DeletionTimestamp.Add(p.stuckTerminatingPodExpiry).Before(time.Now()) { + if pod.DeletionTimestamp != nil && pod.DeletionTimestamp.Add(p.stuckTerminatingPodExpiry).Before(p.clock.Now()) { // pod is stuck in terminating phase, this sometimes happen on node failure // it is safer to produce failed event than retrying as the job might have run already issue := &podIssue{ OriginalPodState: pod.DeepCopy(), - JobId: util.ExtractJobId(pod), - RunId: util.ExtractJobRunId(pod), Message: "pod stuck in terminating phase, this might be due to platform problems", Retryable: false, Type: StuckTerminating, } - p.registerIssue(issue) + p.registerIssue(&runIssue{ + JobId: util.ExtractJobId(pod), + RunId: util.ExtractJobRunId(pod), + PodIssue: issue, + }) } else if pod.Status.Phase == v1.PodUnknown || pod.Status.Phase == v1.PodPending { podEvents, err := p.clusterContext.GetPodEvents(pod) @@ -155,7 +180,7 @@ func (p *PodIssueService) detectPodIssues(allManagedPods []*v1.Pod) { continue } - action, cause, podCheckMessage := p.pendingPodChecker.GetAction(pod, podEvents, time.Now().Sub(lastStateChange)) + action, cause, podCheckMessage := p.pendingPodChecker.GetAction(pod, podEvents, p.clock.Now().Sub(lastStateChange)) if action != podchecks.ActionWait { retryable := action == podchecks.ActionRetry @@ -169,25 +194,27 @@ func (p *PodIssueService) detectPodIssues(allManagedPods []*v1.Pod) { issue := &podIssue{ OriginalPodState: pod.DeepCopy(), - JobId: util.ExtractJobId(pod), - RunId: util.ExtractJobRunId(pod), Message: message, Retryable: retryable, Type: podIssueType, } - p.registerIssue(issue) + p.registerIssue(&runIssue{ + JobId: 
util.ExtractJobId(pod), + RunId: util.ExtractJobRunId(pod), + PodIssue: issue, + }) } } } } -func (p *PodIssueService) handleKnownPodIssues(ctx context.Context, allManagedPods []*v1.Pod) { +func (p *IssueHandler) handleKnownIssues(ctx context.Context, allManagedPods []*v1.Pod) { // Make issues from pods + issues issues := createIssues(allManagedPods, p.knownPodIssues) - util.ProcessItemsWithThreadPool(ctx, 20, issues, p.handlePodIssue) + util.ProcessItemsWithThreadPool(ctx, 20, issues, p.handleRunIssue) } -func createIssues(managedPods []*v1.Pod, podIssues map[string]*podIssue) []*issue { +func createIssues(managedPods []*v1.Pod, runIssues map[string]*runIssue) []*issue { podsByRunId := make(map[string]*v1.Pod, len(managedPods)) for _, pod := range managedPods { @@ -199,25 +226,40 @@ func createIssues(managedPods []*v1.Pod, podIssues map[string]*podIssue) []*issu } } - result := make([]*issue, 0, len(podIssues)) + result := make([]*issue, 0, len(runIssues)) - for _, podIssue := range podIssues { - relatedPod := podsByRunId[podIssue.RunId] - result = append(result, &issue{CurrentPodState: relatedPod, Issue: podIssue}) + for _, runIssue := range runIssues { + relatedPod := podsByRunId[runIssue.RunId] + result = append(result, &issue{CurrentPodState: relatedPod, RunIssue: runIssue}) } return result } -func (p *PodIssueService) handlePodIssue(issue *issue) { +func (p *IssueHandler) handleRunIssue(issue *issue) { + if issue == nil || issue.RunIssue == nil { + log.Warnf("issue found with missing issue details") + return + } + if issue.RunIssue.PodIssue != nil { + p.handlePodIssue(issue) + } else if issue.RunIssue.ReconciliationIssue != nil { + p.handleReconciliationIssue(issue) + } else { + log.Warnf("issue found with no issue details set for job %s run %s", issue.RunIssue.JobId, issue.RunIssue.RunId) + p.markIssuesResolved(issue.RunIssue) + } +} + +func (p *IssueHandler) handlePodIssue(issue *issue) { hasSelfResolved := hasPodIssueSelfResolved(issue) if hasSelfResolved { - log.Infof("Issue for job %s run %s has self resolved", issue.Issue.JobId, issue.Issue.RunId) - p.markIssuesResolved(issue.Issue) + log.Infof("Issue for job %s run %s has self resolved", issue.RunIssue.JobId, issue.RunIssue.RunId) + p.markIssuesResolved(issue.RunIssue) return } - if issue.Issue.Retryable { + if issue.RunIssue.PodIssue.Retryable { p.handleRetryableJobIssue(issue) } else { p.handleNonRetryableJobIssue(issue) @@ -229,32 +271,32 @@ func (p *PodIssueService) handlePodIssue(issue *issue) { // - Report JobFailedEvent // // Once that is done we are free to cleanup the pod -func (p *PodIssueService) handleNonRetryableJobIssue(issue *issue) { - if !issue.Issue.Reported { - log.Infof("Non-retryable issue detected for job %s run %s - %s", issue.Issue.JobId, issue.Issue.RunId, issue.Issue.Message) - message := issue.Issue.Message +func (p *IssueHandler) handleNonRetryableJobIssue(issue *issue) { + if !issue.RunIssue.Reported { + log.Infof("Non-retryable issue detected for job %s run %s - %s", issue.RunIssue.JobId, issue.RunIssue.RunId, issue.RunIssue.PodIssue.Message) + message := issue.RunIssue.PodIssue.Message events := make([]reporter.EventMessage, 0, 2) - if issue.Issue.Type == StuckStartingUp || issue.Issue.Type == UnableToSchedule { - unableToScheduleEvent := reporter.CreateJobUnableToScheduleEvent(issue.Issue.OriginalPodState, message, p.clusterContext.GetClusterId()) - events = append(events, reporter.EventMessage{Event: unableToScheduleEvent, JobRunId: issue.Issue.RunId}) + if issue.RunIssue.PodIssue.Type == 
StuckStartingUp || issue.RunIssue.PodIssue.Type == UnableToSchedule { + unableToScheduleEvent := reporter.CreateJobUnableToScheduleEvent(issue.RunIssue.PodIssue.OriginalPodState, message, p.clusterContext.GetClusterId()) + events = append(events, reporter.EventMessage{Event: unableToScheduleEvent, JobRunId: issue.RunIssue.RunId}) } - failedEvent := reporter.CreateSimpleJobFailedEvent(issue.Issue.OriginalPodState, message, p.clusterContext.GetClusterId(), issue.Issue.Cause) - events = append(events, reporter.EventMessage{Event: failedEvent, JobRunId: issue.Issue.RunId}) + failedEvent := reporter.CreateSimpleJobFailedEvent(issue.RunIssue.PodIssue.OriginalPodState, message, p.clusterContext.GetClusterId(), issue.RunIssue.PodIssue.Cause) + events = append(events, reporter.EventMessage{Event: failedEvent, JobRunId: issue.RunIssue.RunId}) err := p.eventReporter.Report(events) if err != nil { - log.Errorf("Failed to report failed event for job %s because %s", issue.Issue.JobId, err) + log.Errorf("Failed to report failed event for job %s because %s", issue.RunIssue.JobId, err) return } - p.markIssueReported(issue.Issue) + p.markIssueReported(issue.RunIssue) } if issue.CurrentPodState != nil { p.clusterContext.DeletePods([]*v1.Pod{issue.CurrentPodState}) - issue.Issue.DeletionRequested = true + issue.RunIssue.PodIssue.DeletionRequested = true } else { - p.markIssuesResolved(issue.Issue) + p.markIssuesResolved(issue.RunIssue) } } @@ -262,76 +304,82 @@ func (p *PodIssueService) handleNonRetryableJobIssue(issue *issue) { // - Report JobUnableToScheduleEvent // - Report JobReturnLeaseEvent // -// Special consideration must be taken that most of these pods are somewhat "stuck" in pending. -// So can transition to Running/Completed/Failed in the middle of this -// We must not return the lease if the pod state changes - as likely it has become "unstuck" -func (p *PodIssueService) handleRetryableJobIssue(issue *issue) { - if !issue.Issue.Reported { - log.Infof("Retryable issue detected for job %s run %s - %s", issue.Issue.JobId, issue.Issue.RunId, issue.Issue.Message) - if issue.Issue.Type == StuckStartingUp || issue.Issue.Type == UnableToSchedule { - event := reporter.CreateJobUnableToScheduleEvent(issue.Issue.OriginalPodState, issue.Issue.Message, p.clusterContext.GetClusterId()) - err := p.eventReporter.Report([]reporter.EventMessage{{Event: event, JobRunId: issue.Issue.RunId}}) +// If the pod becomes Running/Completed/Failed in the middle of being deleted - swap this issue to a nonRetryableIssue where it will be Failed +func (p *IssueHandler) handleRetryableJobIssue(issue *issue) { + if !issue.RunIssue.Reported { + log.Infof("Retryable issue detected for job %s run %s - %s", issue.RunIssue.JobId, issue.RunIssue.RunId, issue.RunIssue.PodIssue.Message) + if issue.RunIssue.PodIssue.Type == StuckStartingUp || issue.RunIssue.PodIssue.Type == UnableToSchedule { + event := reporter.CreateJobUnableToScheduleEvent(issue.RunIssue.PodIssue.OriginalPodState, issue.RunIssue.PodIssue.Message, p.clusterContext.GetClusterId()) + err := p.eventReporter.Report([]reporter.EventMessage{{Event: event, JobRunId: issue.RunIssue.RunId}}) if err != nil { log.Errorf("Failure to report stuck pod event %+v because %s", event, err) return } } - p.markIssueReported(issue.Issue) + p.markIssueReported(issue.RunIssue) } if issue.CurrentPodState != nil { - // TODO consider moving this to a synchronous call - but long termination periods would need to be handled + if issue.CurrentPodState.Status.Phase != v1.PodPending { + 
p.markIssuesResolved(issue.RunIssue) + if issue.RunIssue.PodIssue.DeletionRequested { + p.registerIssue(&runIssue{ + JobId: issue.RunIssue.JobId, + RunId: issue.RunIssue.RunId, + PodIssue: &podIssue{ + OriginalPodState: issue.RunIssue.PodIssue.OriginalPodState, + Message: "Pod unexpectedly started up after delete was called", + Retryable: false, + DeletionRequested: false, + Type: ErrorDuringIssueHandling, + Cause: api.Cause_Error, + }, + }) + } + return + } + err := p.clusterContext.DeletePodWithCondition(issue.CurrentPodState, func(pod *v1.Pod) bool { return pod.Status.Phase == v1.PodPending }, true) if err != nil { - log.Errorf("Failed to delete pod of running job %s because %s", issue.Issue.JobId, err) + log.Errorf("Failed to delete pod of running job %s because %s", issue.RunIssue.JobId, err) return } else { - issue.Issue.DeletionRequested = true + issue.RunIssue.PodIssue.DeletionRequested = true } } else { // TODO // When we have our own internal state - we don't need to wait for the pod deletion to complete // We can just mark is to delete in our state and return the lease - jobRunAttempted := issue.Issue.Type != UnableToSchedule - returnLeaseEvent := reporter.CreateReturnLeaseEvent(issue.Issue.OriginalPodState, issue.Issue.Message, p.clusterContext.GetClusterId(), jobRunAttempted) - err := p.eventReporter.Report([]reporter.EventMessage{{Event: returnLeaseEvent, JobRunId: issue.Issue.RunId}}) + jobRunAttempted := issue.RunIssue.PodIssue.Type != UnableToSchedule + returnLeaseEvent := reporter.CreateReturnLeaseEvent(issue.RunIssue.PodIssue.OriginalPodState, issue.RunIssue.PodIssue.Message, p.clusterContext.GetClusterId(), jobRunAttempted) + err := p.eventReporter.Report([]reporter.EventMessage{{Event: returnLeaseEvent, JobRunId: issue.RunIssue.RunId}}) if err != nil { - log.Errorf("Failed to return lease for job %s because %s", issue.Issue.JobId, err) + log.Errorf("Failed to return lease for job %s because %s", issue.RunIssue.JobId, err) return } - p.markIssuesResolved(issue.Issue) + p.markIssuesResolved(issue.RunIssue) } } func hasPodIssueSelfResolved(issue *issue) bool { - if issue == nil || issue.Issue == nil { + if issue == nil || issue.RunIssue == nil || issue.RunIssue.PodIssue == nil { return true } - isStuckStartingUpAndResolvable := issue.Issue.Type == StuckStartingUp && - (issue.Issue.Retryable || (!issue.Issue.Retryable && !issue.Issue.Reported)) - if issue.Issue.Type == UnableToSchedule || isStuckStartingUpAndResolvable { + isStuckStartingUpAndResolvable := issue.RunIssue.PodIssue.Type == StuckStartingUp && + (issue.RunIssue.PodIssue.Retryable || (!issue.RunIssue.PodIssue.Retryable && !issue.RunIssue.Reported)) + if issue.RunIssue.PodIssue.Type == UnableToSchedule || isStuckStartingUpAndResolvable { // If pod has disappeared - don't consider it resolved as we still need to report the issue if issue.CurrentPodState == nil { return false } - // Pod has completed - no need to report any issues - if util.IsInTerminalState(issue.CurrentPodState) { - return true - } - - // Pod has started running, and we haven't requested deletion - let it continue - if issue.CurrentPodState.Status.Phase == v1.PodRunning && !issue.Issue.DeletionRequested { + // Pod has started up and we haven't tried to delete the pod yet - so resolve the issue + if issue.CurrentPodState.Status.Phase != v1.PodPending && !issue.RunIssue.PodIssue.DeletionRequested { return true } - // TODO There is an edge case here where the pod has started running but we have requested deletion - // Without a proper state model, 
we can't easily handle this correctly - // Ideally we'd see if it completes or deletes first and report it accordingly - // If it completes first - do nothing - // If it deletes first - report JobFailed (as we accidentally deleted it during the run) } return false @@ -344,19 +392,107 @@ func createStuckPodMessage(retryable bool, originalMessage string) string { return fmt.Sprintf("Unable to schedule pod with unrecoverable problem, Armada will not retry.\n%s", originalMessage) } -func (p *PodIssueService) handleDeletedPod(pod *v1.Pod) { +func (p *IssueHandler) handleDeletedPod(pod *v1.Pod) { jobId := util.ExtractJobId(pod) if jobId != "" { isUnexpectedDeletion := !util.IsMarkedForDeletion(pod) && !util.IsPodFinishedAndReported(pod) if isUnexpectedDeletion { - p.registerIssue(&podIssue{ - OriginalPodState: pod, - JobId: jobId, - RunId: util.ExtractJobRunId(pod), - Message: "Pod was unexpectedly deleted", - Retryable: false, - Reported: false, - Type: ExternallyDeleted, + p.registerIssue(&runIssue{ + JobId: jobId, + RunId: util.ExtractJobRunId(pod), + PodIssue: &podIssue{ + OriginalPodState: pod, + Message: "Pod was unexpectedly deleted", + Retryable: false, + Type: ExternallyDeleted, + }, + }) + } + } +} + +func (p *IssueHandler) handleReconciliationIssue(issue *issue) { + if issue.RunIssue.ReconciliationIssue == nil { + log.Warnf("unexpected trying to process an issue as a reconciliation issue for job %s run %s", issue.RunIssue.JobId, issue.RunIssue.RunId) + p.markIssuesResolved(issue.RunIssue) + return + } + + currentRunState := p.jobRunState.Get(issue.RunIssue.RunId) + if currentRunState == nil { + // No run for the run id - so there isn't a reconciliation issue + p.markIssuesResolved(issue.RunIssue) + return + } + + if issue.CurrentPodState != nil { + p.markIssuesResolved(issue.RunIssue) + return + } + + if issue.RunIssue.ReconciliationIssue.OriginalRunState.Phase != currentRunState.Phase || currentRunState.CancelRequested || currentRunState.PreemptionRequested { + // State of the run has changed - resolve + // If there is still an issue, it'll be re-detected + p.markIssuesResolved(issue.RunIssue) + return + } + + timeSinceInitialDetection := p.clock.Now().Sub(issue.RunIssue.ReconciliationIssue.InitialDetectionTime) + + // If there is an active run and the associated pod has been missing for more than a given time period, report the run as failed + if currentRunState.Phase == job.Active && timeSinceInitialDetection > p.stateChecksConfig.DeadlineForActivePodConsideredMissing { + log.Infof("Pod missing for active run detected for job %s run %s", issue.RunIssue.JobId, issue.RunIssue.RunId) + + event := &api.JobFailedEvent{ + JobId: currentRunState.Meta.JobId, + JobSetId: currentRunState.Meta.JobSet, + Queue: currentRunState.Meta.Queue, + Created: p.clock.Now(), + ClusterId: p.clusterContext.GetClusterId(), + Reason: fmt.Sprintf("Pod is unexpectedly missing in Kubernetes"), + Cause: api.Cause_Error, + } + + err := p.eventReporter.Report([]reporter.EventMessage{{Event: event, JobRunId: issue.RunIssue.RunId}}) + if err != nil { + log.Errorf("Failure to report failed event %+v because %s", event, err) + return + } + + p.markIssueReported(issue.RunIssue) + p.markIssuesResolved(issue.RunIssue) + } else if currentRunState.Phase == job.SuccessfulSubmission && timeSinceInitialDetection > p.stateChecksConfig.DeadlineForSubmittedPodConsideredMissing { + // If a pod hasn't shown up after a successful submission for a given time period, delete it from the run state + // This will cause it to be 
re-leased and submitted again + // If the issue is we are out of sync with kubernetes, the second submission will fail and kill the job + p.jobRunState.Delete(currentRunState.Meta.RunId) + p.markIssuesResolved(issue.RunIssue) + } +} + +func (p *IssueHandler) detectReconciliationIssues(pods []*v1.Pod) { + runs := p.jobRunState.GetAllWithFilter(func(state *job.RunState) bool { + return (state.Phase == job.Active || state.Phase == job.SuccessfulSubmission) && !state.CancelRequested && !state.PreemptionRequested + }) + + runIdsToPod := make(map[string]*v1.Pod, len(pods)) + for _, pod := range pods { + runId := util.ExtractJobRunId(pod) + if runId != "" { + runIdsToPod[runId] = pod + } + } + + for _, run := range runs { + _, present := runIdsToPod[run.Meta.RunId] + if !present { + p.registerIssue(&runIssue{ + JobId: run.Meta.JobId, + RunId: run.Meta.RunId, + ReconciliationIssue: &reconciliationIssue{ + InitialDetectionTime: p.clock.Now(), + OriginalRunState: run.DeepCopy(), + }, }) } } diff --git a/internal/executor/service/pod_issue_handler_test.go b/internal/executor/service/pod_issue_handler_test.go index bab9ea8bb2c..ccb8226d43d 100644 --- a/internal/executor/service/pod_issue_handler_test.go +++ b/internal/executor/service/pod_issue_handler_test.go @@ -6,8 +6,11 @@ import ( "github.com/stretchr/testify/assert" v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/clock" + "github.com/armadaproject/armada/internal/executor/configuration" fakecontext "github.com/armadaproject/armada/internal/executor/context/fake" + "github.com/armadaproject/armada/internal/executor/job" "github.com/armadaproject/armada/internal/executor/reporter" "github.com/armadaproject/armada/internal/executor/reporter/mocks" "github.com/armadaproject/armada/internal/executor/util" @@ -15,7 +18,7 @@ import ( ) func TestPodIssueService_DoesNothingIfNoPodsAreFound(t *testing.T) { - podIssueService, _, eventsReporter := setupTestComponents() + podIssueService, _, _, eventsReporter := setupTestComponents([]*job.RunState{}) podIssueService.HandlePodIssues() @@ -23,7 +26,7 @@ func TestPodIssueService_DoesNothingIfNoPodsAreFound(t *testing.T) { } func TestPodIssueService_DoesNothingIfNoStuckPodsAreFound(t *testing.T) { - podIssueService, fakeClusterContext, eventsReporter := setupTestComponents() + podIssueService, _, fakeClusterContext, eventsReporter := setupTestComponents([]*job.RunState{}) runningPod := makeRunningPod(false) addPod(t, fakeClusterContext, runningPod) @@ -35,7 +38,7 @@ func TestPodIssueService_DoesNothingIfNoStuckPodsAreFound(t *testing.T) { } func TestPodIssueService_DeletesPodAndReportsFailed_IfStuckAndUnretryable(t *testing.T) { - podIssueService, fakeClusterContext, eventsReporter := setupTestComponents() + podIssueService, _, fakeClusterContext, eventsReporter := setupTestComponents([]*job.RunState{}) unretryableStuckPod := makeUnretryableStuckPod(false) addPod(t, fakeClusterContext, unretryableStuckPod) @@ -54,7 +57,7 @@ func TestPodIssueService_DeletesPodAndReportsFailed_IfStuckAndUnretryable(t *tes } func TestPodIssueService_DeletesPodAndReportsFailed_IfStuckTerminating(t *testing.T) { - podIssueService, fakeClusterContext, eventsReporter := setupTestComponents() + podIssueService, _, fakeClusterContext, eventsReporter := setupTestComponents([]*job.RunState{}) terminatingPod := makeTerminatingPod(false) addPod(t, fakeClusterContext, terminatingPod) @@ -70,7 +73,7 @@ func TestPodIssueService_DeletesPodAndReportsFailed_IfStuckTerminating(t *testin } func 
TestPodIssueService_DeletesPodAndReportsLeaseReturned_IfRetryableStuckPod(t *testing.T) { - podIssueService, fakeClusterContext, eventsReporter := setupTestComponents() + podIssueService, _, fakeClusterContext, eventsReporter := setupTestComponents([]*job.RunState{}) retryableStuckPod := makeRetryableStuckPod(false) addPod(t, fakeClusterContext, retryableStuckPod) @@ -94,8 +97,39 @@ func TestPodIssueService_DeletesPodAndReportsLeaseReturned_IfRetryableStuckPod(t assert.True(t, ok) } +func TestPodIssueService_DeletesPodAndReportsFailed_IfRetryableStuckPodStartsUpAfterDeletionCalled(t *testing.T) { + podIssueService, _, fakeClusterContext, eventsReporter := setupTestComponents([]*job.RunState{}) + retryableStuckPod := makeRetryableStuckPod(false) + addPod(t, fakeClusterContext, retryableStuckPod) + + podIssueService.HandlePodIssues() + + // Reports UnableToSchedule + assert.Len(t, eventsReporter.ReceivedEvents, 1) + _, ok := eventsReporter.ReceivedEvents[0].Event.(*api.JobUnableToScheduleEvent) + assert.True(t, ok) + + // Reset events, and add pod back as running + eventsReporter.ReceivedEvents = []reporter.EventMessage{} + retryableStuckPod.Status.Phase = v1.PodRunning + addPod(t, fakeClusterContext, retryableStuckPod) + + // Detects pod is now unexpectedly running and marks it non-retryable + podIssueService.HandlePodIssues() + assert.Len(t, eventsReporter.ReceivedEvents, 0) + assert.Len(t, getActivePods(t, fakeClusterContext), 1) + + // Now processes the issue as non-retryable and fails the pod + podIssueService.HandlePodIssues() + assert.Len(t, getActivePods(t, fakeClusterContext), 0) + + assert.Len(t, eventsReporter.ReceivedEvents, 1) + _, ok = eventsReporter.ReceivedEvents[0].Event.(*api.JobFailedEvent) + assert.True(t, ok) +} + func TestPodIssueService_ReportsFailed_IfDeletedExternally(t *testing.T) { - podIssueService, fakeClusterContext, eventsReporter := setupTestComponents() + podIssueService, _, fakeClusterContext, eventsReporter := setupTestComponents([]*job.RunState{}) runningPod := makeRunningPod(false) fakeClusterContext.SimulateDeletionEvent(runningPod) @@ -108,17 +142,108 @@ func TestPodIssueService_ReportsFailed_IfDeletedExternally(t *testing.T) { assert.Equal(t, failedEvent.JobId, util.ExtractJobId(runningPod)) } -func setupTestComponents() (*PodIssueService, *fakecontext.SyncFakeClusterContext, *mocks.FakeEventReporter) { +func TestPodIssueService_ReportsFailed_IfPodOfActiveRunGoesMissing(t *testing.T) { + baseTime := time.Now() + fakeClock := clock.NewFakeClock(baseTime) + podIssueService, _, _, eventsReporter := setupTestComponents([]*job.RunState{createRunState("job-1", "run-1", job.Active)}) + podIssueService.clock = fakeClock + + podIssueService.HandlePodIssues() + // Nothing should happen, until the issue has been seen for a configured amount of time + assert.Len(t, eventsReporter.ReceivedEvents, 0) + + fakeClock.SetTime(baseTime.Add(10 * time.Minute)) + podIssueService.HandlePodIssues() + // Reports Failed + assert.Len(t, eventsReporter.ReceivedEvents, 1) + failedEvent, ok := eventsReporter.ReceivedEvents[0].Event.(*api.JobFailedEvent) + assert.True(t, ok) + assert.Equal(t, failedEvent.JobId, "job-1") +} + +func TestPodIssueService_DoesNothing_IfMissingPodOfActiveRunReturns(t *testing.T) { + baseTime := time.Now() + fakeClock := clock.NewFakeClock(baseTime) + runningPod := makeRunningPod(false) + runState := createRunState(util.ExtractJobId(runningPod), util.ExtractJobRunId(runningPod), job.Active) + podIssueService, _, fakeClusterContext, eventsReporter := 
setupTestComponents([]*job.RunState{runState}) + podIssueService.clock = fakeClock + + podIssueService.HandlePodIssues() + // Nothing should happen, until the issue has been seen for a configured amount of time + assert.Len(t, eventsReporter.ReceivedEvents, 0) + + addPod(t, fakeClusterContext, runningPod) + fakeClock.SetTime(baseTime.Add(10 * time.Minute)) + podIssueService.HandlePodIssues() + assert.Len(t, eventsReporter.ReceivedEvents, 0) +} + +func TestPodIssueService_DeleteRunFromRunState_IfSubmittedPodNeverAppears(t *testing.T) { + baseTime := time.Now() + fakeClock := clock.NewFakeClock(baseTime) + podIssueService, runStateStore, _, eventsReporter := setupTestComponents([]*job.RunState{createRunState("job-1", "run-1", job.SuccessfulSubmission)}) + podIssueService.clock = fakeClock + + podIssueService.HandlePodIssues() + // Nothing should happen, until the issue has been seen for a configured amount of time + assert.Len(t, eventsReporter.ReceivedEvents, 0) + assert.Len(t, runStateStore.GetAll(), 1) + + fakeClock.SetTime(baseTime.Add(20 * time.Minute)) + podIssueService.HandlePodIssues() + assert.Len(t, eventsReporter.ReceivedEvents, 0) + // Pod has been missing for greater than configured period, run should get deleted + assert.Len(t, runStateStore.GetAll(), 0) +} + +func TestPodIssueService_DoesNothing_IfSubmittedPodAppears(t *testing.T) { + baseTime := time.Now() + fakeClock := clock.NewFakeClock(baseTime) + runningPod := makeRunningPod(false) + runState := createRunState(util.ExtractJobId(runningPod), util.ExtractJobRunId(runningPod), job.SuccessfulSubmission) + podIssueService, runStateStore, fakeClusterContext, eventsReporter := setupTestComponents([]*job.RunState{runState}) + podIssueService.clock = fakeClock + + podIssueService.HandlePodIssues() + // Nothing should happen, until the issue has been seen for a configured amount of time + assert.Len(t, eventsReporter.ReceivedEvents, 0) + assert.Len(t, runStateStore.GetAll(), 1) + + addPod(t, fakeClusterContext, runningPod) + fakeClock.SetTime(baseTime.Add(20 * time.Minute)) + podIssueService.HandlePodIssues() + assert.Len(t, runStateStore.GetAll(), 1) +} + +func setupTestComponents(initialRunState []*job.RunState) (*IssueHandler, *job.JobRunStateStore, *fakecontext.SyncFakeClusterContext, *mocks.FakeEventReporter) { fakeClusterContext := fakecontext.NewSyncFakeClusterContext() eventReporter := mocks.NewFakeEventReporter() pendingPodChecker := makePodChecker() - - podIssueHandler := NewPodIssueService( + runStateStore := job.NewJobRunStateStoreWithInitialState(initialRunState) + stateChecksConfig := configuration.StateChecksConfiguration{ + DeadlineForSubmittedPodConsideredMissing: time.Minute * 15, + DeadlineForActivePodConsideredMissing: time.Minute * 5, + } + + podIssueHandler := NewIssueHandler( + runStateStore, fakeClusterContext, eventReporter, + stateChecksConfig, pendingPodChecker, time.Minute*3, ) - return podIssueHandler, fakeClusterContext, eventReporter + return podIssueHandler, runStateStore, fakeClusterContext, eventReporter +} + +func createRunState(jobId string, runId string, phase job.RunPhase) *job.RunState { + return &job.RunState{ + Phase: phase, + Meta: &job.RunMeta{ + JobId: jobId, + RunId: runId, + }, + } } diff --git a/internal/lookoutv2/conversions/convert_test.go b/internal/lookoutv2/conversions/convert_test.go index 9d5649156ac..32130e63ff3 100644 --- a/internal/lookoutv2/conversions/convert_test.go +++ b/internal/lookoutv2/conversions/convert_test.go @@ -86,16 +86,22 @@ var ( } swaggerGroup = 
&models.Group{ - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "averageTimeInState": "3d", + "state": map[string]int{ + "QUEUED": 321, + }, }, Count: 1000, Name: "queue-1", } group = &model.JobGroup{ - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "averageTimeInState": "3d", + "state": map[string]int{ + "QUEUED": 321, + }, }, Count: 1000, Name: "queue-1", diff --git a/internal/lookoutv2/gen/models/group.go b/internal/lookoutv2/gen/models/group.go index 71adda73be1..25c8d68892a 100644 --- a/internal/lookoutv2/gen/models/group.go +++ b/internal/lookoutv2/gen/models/group.go @@ -21,7 +21,7 @@ type Group struct { // aggregates // Required: true - Aggregates map[string]string `json:"aggregates"` + Aggregates map[string]interface{} `json:"aggregates"` // count // Required: true @@ -61,6 +61,14 @@ func (m *Group) validateAggregates(formats strfmt.Registry) error { return err } + for k := range m.Aggregates { + + if err := validate.Required("aggregates"+"."+k, "body", m.Aggregates[k]); err != nil { + return err + } + + } + return nil } diff --git a/internal/lookoutv2/gen/restapi/embedded_spec.go b/internal/lookoutv2/gen/restapi/embedded_spec.go index 629d7b45500..c57b6290da2 100644 --- a/internal/lookoutv2/gen/restapi/embedded_spec.go +++ b/internal/lookoutv2/gen/restapi/embedded_spec.go @@ -426,7 +426,7 @@ func init() { "aggregates": { "type": "object", "additionalProperties": { - "type": "string" + "type": "object" }, "x-nullable": false }, @@ -1082,7 +1082,7 @@ func init() { "aggregates": { "type": "object", "additionalProperties": { - "type": "string" + "type": "object" }, "x-nullable": false }, diff --git a/internal/lookoutv2/model/model.go b/internal/lookoutv2/model/model.go index 349541d54cb..0d22f87ec3c 100644 --- a/internal/lookoutv2/model/model.go +++ b/internal/lookoutv2/model/model.go @@ -53,7 +53,7 @@ type Run struct { } type JobGroup struct { - Aggregates map[string]string + Aggregates map[string]interface{} Count int64 Name string } diff --git a/internal/lookoutv2/repository/aggregates.go b/internal/lookoutv2/repository/aggregates.go new file mode 100644 index 00000000000..ad7c1386dba --- /dev/null +++ b/internal/lookoutv2/repository/aggregates.go @@ -0,0 +1,133 @@ +package repository + +import ( + "fmt" + + "github.com/pkg/errors" + + "github.com/armadaproject/armada/internal/common/database/lookout" + "github.com/armadaproject/armada/internal/common/util" + "github.com/armadaproject/armada/internal/lookoutv2/model" +) + +type QueryAggregator interface { + AggregateSql() (string, error) +} + +type SqlFunctionAggregator struct { + queryCol *queryColumn + sqlFunction string +} + +func NewSqlFunctionAggregator(queryCol *queryColumn, fn string) *SqlFunctionAggregator { + return &SqlFunctionAggregator{ + queryCol: queryCol, + sqlFunction: fn, + } +} + +func (qa *SqlFunctionAggregator) aggregateColName() string { + return qa.queryCol.name +} + +func (qa *SqlFunctionAggregator) AggregateSql() (string, error) { + return fmt.Sprintf("%s(%s.%s) AS %s", qa.sqlFunction, qa.queryCol.abbrev, qa.queryCol.name, qa.aggregateColName()), nil +} + +type StateCountAggregator struct { + queryCol *queryColumn + stateString string +} + +func NewStateCountAggregator(queryCol *queryColumn, stateString string) *StateCountAggregator { + return &StateCountAggregator{ + queryCol: queryCol, + stateString: stateString, + } +} + +func (qa *StateCountAggregator) aggregateColName() string { + return fmt.Sprintf("%s_%s", qa.queryCol.name, qa.stateString) +} + +func (qa 
*StateCountAggregator) AggregateSql() (string, error) { + stateInt, ok := lookout.JobStateOrdinalMap[lookout.JobState(qa.stateString)] + if !ok { + return "", errors.Errorf("state %s does not exist", qa.stateString) + } + return fmt.Sprintf( + "SUM(CASE WHEN %s.%s = %d THEN 1 ELSE 0 END) AS %s", + qa.queryCol.abbrev, qa.queryCol.name, stateInt, qa.aggregateColName(), + ), nil +} + +func GetAggregatorsForColumn(queryCol *queryColumn, aggregateType AggregateType, filters []*model.Filter) ([]QueryAggregator, error) { + switch aggregateType { + case Max: + return []QueryAggregator{NewSqlFunctionAggregator(queryCol, "MAX")}, nil + case Average: + return []QueryAggregator{NewSqlFunctionAggregator(queryCol, "AVG")}, nil + case StateCounts: + states := GetStatesForFilter(filters) + aggregators := make([]QueryAggregator, len(states)) + for i, state := range states { + aggregators[i] = NewStateCountAggregator(queryCol, state) + } + return aggregators, nil + default: + return nil, errors.Errorf("cannot determine aggregate type: %v", aggregateType) + } +} + +// GetStatesForFilter returns a list of states as string if filter for state exists +// Will always return the states in the same order, irrespective of the ordering of the states in the filter +func GetStatesForFilter(filters []*model.Filter) []string { + var stateFilter *model.Filter + for _, f := range filters { + if f.Field == stateField { + stateFilter = f + } + } + allStates := util.Map(lookout.JobStates, func(jobState lookout.JobState) string { return string(jobState) }) + if stateFilter == nil { + // If no state filter is specified, use all states + return allStates + } + + switch stateFilter.Match { + case model.MatchExact: + return []string{fmt.Sprintf("%s", stateFilter.Value)} + case model.MatchAnyOf: + strSlice, err := toStringSlice(stateFilter.Value) + if err != nil { + return allStates + } + stateStringSet := util.StringListToSet(strSlice) + // Ensuring they are in the same order + var finalStates []string + for _, state := range allStates { + if _, ok := stateStringSet[state]; ok { + finalStates = append(finalStates, state) + } + } + return finalStates + default: + return allStates + } +} + +func toStringSlice(val interface{}) ([]string, error) { + switch v := val.(type) { + case []string: + return v, nil + case []interface{}: + result := make([]string, len(v)) + for i := 0; i < len(v); i++ { + str := fmt.Sprintf("%v", v[i]) + result[i] = str + } + return result, nil + default: + return nil, errors.Errorf("failed to convert interface to string slice: %v of type %T", val, val) + } +} diff --git a/internal/lookoutv2/repository/fieldparser.go b/internal/lookoutv2/repository/fieldparser.go new file mode 100644 index 00000000000..e8ddde0996b --- /dev/null +++ b/internal/lookoutv2/repository/fieldparser.go @@ -0,0 +1,122 @@ +package repository + +import ( + "fmt" + "math" + "time" + + "github.com/jackc/pgtype" + "github.com/pkg/errors" + + "github.com/armadaproject/armada/internal/common/database/lookout" + "github.com/armadaproject/armada/internal/lookoutv2/model" +) + +type FieldParser interface { + GetField() string + GetVariableRef() interface{} + ParseValue() (interface{}, error) +} + +type LastTransitionTimeParser struct { + variable pgtype.Numeric +} + +func (fp *LastTransitionTimeParser) GetField() string { + return lastTransitionTimeField +} + +func (fp *LastTransitionTimeParser) GetVariableRef() interface{} { + return &fp.variable +} + +func (fp *LastTransitionTimeParser) ParseValue() (interface{}, error) { + var dst float64 + 
err := fp.variable.AssignTo(&dst) + if err != nil { + return "", err + } + t := time.Unix(int64(math.Round(dst)), 0) + return t.Format(time.RFC3339), nil +} + +type TimeParser struct { + field string + variable time.Time +} + +func (fp *TimeParser) GetField() string { + return fp.field +} + +func (fp *TimeParser) GetVariableRef() interface{} { + return &fp.variable +} + +func (fp *TimeParser) ParseValue() (interface{}, error) { + return fp.variable.Format(time.RFC3339), nil +} + +type StateParser struct { + variable int16 +} + +func (fp *StateParser) GetField() string { + return stateField +} + +func (fp *StateParser) GetVariableRef() interface{} { + return &fp.variable +} + +func (fp *StateParser) ParseValue() (interface{}, error) { + state, ok := lookout.JobStateMap[int(fp.variable)] + if !ok { + return "", errors.Errorf("state not found: %d", fp.variable) + } + return string(state), nil +} + +type BasicParser[T any] struct { + field string + variable T +} + +func (fp *BasicParser[T]) GetField() string { + return fp.field +} + +func (fp *BasicParser[T]) GetVariableRef() interface{} { + return &fp.variable +} + +func (fp *BasicParser[T]) ParseValue() (interface{}, error) { + return fp.variable, nil +} + +func ParserForGroup(field string) FieldParser { + switch field { + case stateField: + return &StateParser{} + default: + return &BasicParser[string]{field: field} + } +} + +func ParsersForAggregate(field string, filters []*model.Filter) ([]FieldParser, error) { + var parsers []FieldParser + switch field { + case lastTransitionTimeField: + parsers = append(parsers, &LastTransitionTimeParser{}) + case submittedField: + parsers = append(parsers, &TimeParser{field: submittedField}) + case stateField: + states := GetStatesForFilter(filters) + for _, state := range states { + parsers = append(parsers, &BasicParser[int]{field: fmt.Sprintf("%s%s", stateAggregatePrefix, state)}) + } + default: + return nil, errors.Errorf("no aggregate found for field %s", field) + } + return parsers, nil +} diff --git a/internal/lookoutv2/repository/groupjobs.go b/internal/lookoutv2/repository/groupjobs.go index 1988e4a31ce..f8fe0b37206 100644 --- a/internal/lookoutv2/repository/groupjobs.go +++ b/internal/lookoutv2/repository/groupjobs.go @@ -2,16 +2,14 @@ package repository import ( "context" - "math" - "time" + "fmt" + "strings" - "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/pgxpool" "github.com/pkg/errors" "github.com/armadaproject/armada/internal/common/database" - "github.com/armadaproject/armada/internal/common/database/lookout" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/lookoutv2/model" ) @@ -39,15 +37,7 @@ type SqlGroupJobsRepository struct { lookoutTables *LookoutTables } -type scanVarInit func() interface{} - -type parserFn func(interface{}) (string, error) - -type scanContext struct { - field string - varInit scanVarInit - parser parserFn -} +const stateAggregatePrefix = "state_" func NewSqlGroupJobsRepository(db *pgxpool.Pool) *SqlGroupJobsRepository { return &SqlGroupJobsRepository{ @@ -95,7 +85,7 @@ func (r *SqlGroupJobsRepository) GroupBy( if err != nil { return err } - groups, err = rowsToGroups(groupRows, groupedField, aggregates) + groups, err = rowsToGroups(groupRows, groupedField, aggregates, filters) return err }) if err != nil { @@ -108,10 +98,10 @@ func (r *SqlGroupJobsRepository) GroupBy( }, nil } -func rowsToGroups(rows pgx.Rows, groupedField *model.GroupedField, aggregates []string) ([]*model.JobGroup, 
error) { +func rowsToGroups(rows pgx.Rows, groupedField *model.GroupedField, aggregates []string, filters []*model.Filter) ([]*model.JobGroup, error) { var groups []*model.JobGroup for rows.Next() { - jobGroup, err := scanGroup(rows, groupedField.Field, aggregates) + jobGroup, err := scanGroup(rows, groupedField.Field, aggregates, filters) if err != nil { return nil, err } @@ -120,143 +110,59 @@ func rowsToGroups(rows pgx.Rows, groupedField *model.GroupedField, aggregates [] return groups, nil } -func scanGroup(rows pgx.Rows, field string, aggregates []string) (*model.JobGroup, error) { - groupScanContext, err := groupScanContextForField(field) - if err != nil { - return nil, err - } - group := groupScanContext.varInit() +func scanGroup(rows pgx.Rows, field string, aggregates []string, filters []*model.Filter) (*model.JobGroup, error) { + groupParser := ParserForGroup(field) var count int64 - - scanContexts := make([]*scanContext, len(aggregates)) - aggregateVars := make([]interface{}, len(aggregates)) - for i, aggregate := range aggregates { - sc, err := aggregateScanContextForField(aggregate) + var aggregateParsers []FieldParser + for _, aggregate := range aggregates { + parsers, err := ParsersForAggregate(aggregate, filters) if err != nil { return nil, err } - aggregateVars[i] = sc.varInit() - scanContexts[i] = sc + aggregateParsers = append(aggregateParsers, parsers...) } - aggregateRefs := make([]interface{}, len(aggregates)) - for i := 0; i < len(aggregates); i++ { - aggregateRefs[i] = &aggregateVars[i] + aggregateRefs := make([]interface{}, len(aggregateParsers)) + for i, parser := range aggregateParsers { + aggregateRefs[i] = parser.GetVariableRef() } - varAddresses := util.Concat([]interface{}{&group, &count}, aggregateRefs) - err = rows.Scan(varAddresses...) + varAddresses := util.Concat([]interface{}{groupParser.GetVariableRef(), &count}, aggregateRefs) + err := rows.Scan(varAddresses...) 
if err != nil { return nil, err } - parsedGroup, err := groupScanContext.parser(group) + parsedGroup, err := groupParser.ParseValue() if err != nil { return nil, err } - aggregatesMap := make(map[string]string) - for i, sc := range scanContexts { - val := aggregateVars[i] - parsedVal, err := sc.parser(val) + aggregatesMap := make(map[string]interface{}) + for _, parser := range aggregateParsers { + val, err := parser.ParseValue() if err != nil { - return nil, errors.Wrapf(err, "failed to parse value for field %s", sc.field) + return nil, errors.Wrapf(err, "failed to parse value for field %s", parser.GetField()) + } + if strings.HasPrefix(parser.GetField(), stateAggregatePrefix) { + singleStateCount, ok := val.(int) + if !ok { + return nil, errors.Errorf("failed to parse value for state aggregate: cannot convert value to int: %v: %T", singleStateCount, singleStateCount) + } + stateCountsVal, ok := aggregatesMap[stateField] + if !ok { + stateCountsVal = map[string]int{} + aggregatesMap[stateField] = stateCountsVal + } + stateCounts, ok := stateCountsVal.(map[string]int) + if !ok { + return nil, errors.Errorf("failed to parse value for state aggregate: cannot cast state counts to map") + } + state := parser.GetField()[len(stateAggregatePrefix):] + stateCounts[state] = singleStateCount + } else { + aggregatesMap[parser.GetField()] = val } - aggregatesMap[sc.field] = parsedVal } return &model.JobGroup{ - Name: parsedGroup, + Name: fmt.Sprintf("%s", parsedGroup), Count: count, Aggregates: aggregatesMap, }, nil } - -func groupScanContextForField(field string) (*scanContext, error) { - switch field { - case stateField: - return &scanContext{ - field: field, - varInit: int16ScanVar, - parser: stateParser, - }, nil - default: - return &scanContext{ - field: field, - varInit: stringScanVar, - parser: stringParser, - }, nil - } -} - -func aggregateScanContextForField(field string) (*scanContext, error) { - switch field { - case lastTransitionTimeField: - return &scanContext{ - field: lastTransitionTimeField, - varInit: numericScanVar, - parser: avgLastTransitionTimeParser, - }, nil - case submittedField: - return &scanContext{ - field: submittedField, - varInit: timeScanVar, - parser: maxSubmittedTimeParser, - }, nil - default: - return nil, errors.Errorf("no aggregate found for field %s", field) - } -} - -func stringScanVar() interface{} { - return "" -} - -func int16ScanVar() interface{} { - return int16(0) -} - -func numericScanVar() interface{} { - return pgtype.Numeric{} -} - -func timeScanVar() interface{} { - return time.Time{} -} - -func avgLastTransitionTimeParser(val interface{}) (string, error) { - lastTransitionTimeSeconds, ok := val.(pgtype.Numeric) - if !ok { - return "", errors.Errorf("could not convert %v: %T to int64", val, val) - } - var dst float64 - err := lastTransitionTimeSeconds.AssignTo(&dst) - if err != nil { - return "", err - } - t := time.Unix(int64(math.Round(dst)), 0) - return t.Format(time.RFC3339), nil -} - -func maxSubmittedTimeParser(val interface{}) (string, error) { - maxSubmittedTime, ok := val.(time.Time) - if !ok { - return "", errors.Errorf("could not convert %v: %T to time", val, val) - } - return maxSubmittedTime.Format(time.RFC3339), nil -} - -func stateParser(val interface{}) (string, error) { - stateInt, ok := val.(int16) - if !ok { - return "", errors.Errorf("could not convert %v: %T to int for state", val, val) - } - state, ok := lookout.JobStateMap[int(stateInt)] - if !ok { - return "", errors.Errorf("state not found: %d", stateInt) - } - return 
string(state), nil -} - -func stringParser(val interface{}) (string, error) { - str, ok := val.(string) - if !ok { - return "", errors.Errorf("could not convert %v: %T to string", val, val) - } - return str, nil -} diff --git a/internal/lookoutv2/repository/groupjobs_test.go b/internal/lookoutv2/repository/groupjobs_test.go index 29fb24a507c..2ca98fd8a26 100644 --- a/internal/lookoutv2/repository/groupjobs_test.go +++ b/internal/lookoutv2/repository/groupjobs_test.go @@ -59,17 +59,17 @@ func TestGroupByQueue(t *testing.T) { { Name: "queue-1", Count: 10, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, { Name: "queue-2", Count: 5, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, { Name: "queue-3", Count: 3, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, }) return nil @@ -117,17 +117,17 @@ func TestGroupByJobSet(t *testing.T) { { Name: "job-set-1", Count: 10, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, { Name: "job-set-2", Count: 5, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, { Name: "job-set-3", Count: 3, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, }) return nil @@ -183,22 +183,22 @@ func TestGroupByState(t *testing.T) { { Name: string(lookout.JobQueued), Count: 10, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, { Name: string(lookout.JobPending), Count: 5, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, { Name: string(lookout.JobRunning), Count: 3, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, { Name: string(lookout.JobFailed), Count: 2, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, }) return nil @@ -370,22 +370,22 @@ func TestGroupByWithFilters(t *testing.T) { { Name: string(lookout.JobQueued), Count: 10, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, { Name: string(lookout.JobPending), Count: 5, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, { Name: string(lookout.JobRunning), Count: 3, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, { Name: string(lookout.JobFailed), Count: 2, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, }) return nil @@ -468,21 +468,21 @@ func TestGroupJobsWithMaxSubmittedTime(t *testing.T) { { Name: "job-set-1", Count: 15, - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "submitted": baseTime.Format(time.RFC3339), }, }, { Name: "job-set-2", Count: 12, - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "submitted": baseTime.Add(-4 * time.Minute).Format(time.RFC3339), }, }, { Name: "job-set-3", Count: 18, - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "submitted": baseTime.Add(-7 * time.Minute).Format(time.RFC3339), }, }, @@ -567,21 +567,21 @@ func TestGroupJobsWithAvgLastTransitionTime(t *testing.T) { { Name: "queue-3", Count: 18, - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "lastTransitionTime": baseTime.Add(-8 * time.Minute).Format(time.RFC3339), }, }, { Name: "queue-2", Count: 12, - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "lastTransitionTime": baseTime.Add(-5 * time.Minute).Format(time.RFC3339), }, }, { Name: "queue-1", Count: 15, - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ 
"lastTransitionTime": baseTime.Add(-1 * time.Minute).Format(time.RFC3339), }, }, @@ -591,6 +591,237 @@ func TestGroupJobsWithAvgLastTransitionTime(t *testing.T) { assert.NoError(t, err) } +func TestGroupJobsWithAllStateCounts(t *testing.T) { + err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { + converter := instructions.NewInstructionConverter(metrics.Get(), userAnnotationPrefix, &compress.NoOpCompressor{}, false) + store := lookoutdb.NewLookoutDb(db, metrics.Get(), 3, 10) + + manyJobs(5, &createJobsOpts{ + queue: "queue-1", + jobSet: "job-set-1", + state: lookout.JobQueued, + }, converter, store) + manyJobs(6, &createJobsOpts{ + queue: "queue-1", + jobSet: "job-set-1", + state: lookout.JobPending, + }, converter, store) + manyJobs(7, &createJobsOpts{ + queue: "queue-1", + jobSet: "job-set-1", + state: lookout.JobRunning, + }, converter, store) + + manyJobs(8, &createJobsOpts{ + queue: "queue-2", + jobSet: "job-set-2", + state: lookout.JobLeased, + }, converter, store) + manyJobs(9, &createJobsOpts{ + queue: "queue-2", + jobSet: "job-set-2", + state: lookout.JobPreempted, + }, converter, store) + manyJobs(10, &createJobsOpts{ + queue: "queue-2", + jobSet: "job-set-2", + state: lookout.JobCancelled, + }, converter, store) + + manyJobs(11, &createJobsOpts{ + queue: "queue-3", + jobSet: "job-set-3", + state: lookout.JobSucceeded, + }, converter, store) + manyJobs(12, &createJobsOpts{ + queue: "queue-3", + jobSet: "job-set-3", + state: lookout.JobFailed, + }, converter, store) + manyJobs(13, &createJobsOpts{ + queue: "queue-3", + jobSet: "job-set-3", + state: lookout.JobQueued, + }, converter, store) + + repo := NewSqlGroupJobsRepository(db) + result, err := repo.GroupBy( + context.TODO(), + []*model.Filter{}, + &model.Order{ + Field: "count", + Direction: "ASC", + }, + &model.GroupedField{ + Field: "jobSet", + }, + []string{"state"}, + 0, + 10, + ) + assert.NoError(t, err) + assert.Len(t, result.Groups, 3) + assert.Equal(t, 3, result.Count) + assert.Equal(t, []*model.JobGroup{ + { + Name: "job-set-1", + Count: 18, + Aggregates: map[string]interface{}{ + "state": map[string]int{ + string(lookout.JobQueued): 5, + string(lookout.JobLeased): 0, + string(lookout.JobPending): 6, + string(lookout.JobRunning): 7, + string(lookout.JobSucceeded): 0, + string(lookout.JobFailed): 0, + string(lookout.JobCancelled): 0, + string(lookout.JobPreempted): 0, + }, + }, + }, + { + Name: "job-set-2", + Count: 27, + Aggregates: map[string]interface{}{ + "state": map[string]int{ + string(lookout.JobQueued): 0, + string(lookout.JobLeased): 8, + string(lookout.JobPending): 0, + string(lookout.JobRunning): 0, + string(lookout.JobSucceeded): 0, + string(lookout.JobFailed): 0, + string(lookout.JobCancelled): 10, + string(lookout.JobPreempted): 9, + }, + }, + }, + { + Name: "job-set-3", + Count: 36, + Aggregates: map[string]interface{}{ + "state": map[string]int{ + string(lookout.JobQueued): 13, + string(lookout.JobLeased): 0, + string(lookout.JobPending): 0, + string(lookout.JobRunning): 0, + string(lookout.JobSucceeded): 11, + string(lookout.JobFailed): 12, + string(lookout.JobCancelled): 0, + string(lookout.JobPreempted): 0, + }, + }, + }, + }, result.Groups) + return nil + }) + assert.NoError(t, err) +} + +func TestGroupJobsWithFilteredStateCounts(t *testing.T) { + err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { + converter := instructions.NewInstructionConverter(metrics.Get(), userAnnotationPrefix, &compress.NoOpCompressor{}, false) + store := lookoutdb.NewLookoutDb(db, metrics.Get(), 3, 10) + + 
manyJobs(5, &createJobsOpts{ + queue: "queue-1", + jobSet: "job-set-1", + state: lookout.JobQueued, + }, converter, store) + manyJobs(6, &createJobsOpts{ + queue: "queue-1", + jobSet: "job-set-1", + state: lookout.JobPending, + }, converter, store) + manyJobs(7, &createJobsOpts{ + queue: "queue-1", + jobSet: "job-set-1", + state: lookout.JobRunning, + }, converter, store) + + manyJobs(9, &createJobsOpts{ + queue: "queue-2", + jobSet: "job-set-2", + state: lookout.JobPreempted, + }, converter, store) + manyJobs(10, &createJobsOpts{ + queue: "queue-2", + jobSet: "job-set-2", + state: lookout.JobCancelled, + }, converter, store) + + manyJobs(11, &createJobsOpts{ + queue: "queue-3", + jobSet: "job-set-3", + state: lookout.JobSucceeded, + }, converter, store) + manyJobs(12, &createJobsOpts{ + queue: "queue-3", + jobSet: "job-set-3", + state: lookout.JobFailed, + }, converter, store) + manyJobs(13, &createJobsOpts{ + queue: "queue-3", + jobSet: "job-set-3", + state: lookout.JobQueued, + }, converter, store) + + repo := NewSqlGroupJobsRepository(db) + result, err := repo.GroupBy( + context.TODO(), + []*model.Filter{ + { + Field: stateField, + Match: model.MatchAnyOf, + Value: []string{ + string(lookout.JobQueued), + string(lookout.JobPending), + string(lookout.JobRunning), + }, + }, + }, + &model.Order{ + Field: "count", + Direction: "DESC", + }, + &model.GroupedField{ + Field: "jobSet", + }, + []string{"state"}, + 0, + 10, + ) + assert.NoError(t, err) + assert.Len(t, result.Groups, 2) + assert.Equal(t, 2, result.Count) + assert.Equal(t, []*model.JobGroup{ + { + Name: "job-set-1", + Count: 18, + Aggregates: map[string]interface{}{ + "state": map[string]int{ + string(lookout.JobQueued): 5, + string(lookout.JobPending): 6, + string(lookout.JobRunning): 7, + }, + }, + }, + { + Name: "job-set-3", + Count: 13, + Aggregates: map[string]interface{}{ + "state": map[string]int{ + string(lookout.JobQueued): 13, + string(lookout.JobPending): 0, + string(lookout.JobRunning): 0, + }, + }, + }, + }, result.Groups) + return nil + }) + assert.NoError(t, err) +} + func TestGroupJobsComplex(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { converter := instructions.NewInstructionConverter(metrics.Get(), userAnnotationPrefix, &compress.NoOpCompressor{}, true) @@ -709,7 +940,7 @@ func TestGroupJobsComplex(t *testing.T) { { Name: "job-set-2", Count: 2, - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "submitted": baseTime.Add(20 * time.Minute).Format(time.RFC3339), "lastTransitionTime": baseTime.Add(50 * time.Minute).Format(time.RFC3339), }, @@ -717,7 +948,7 @@ func TestGroupJobsComplex(t *testing.T) { { Name: "job-set-1", Count: 15, - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "submitted": baseTime.Add(3 * time.Minute).Format(time.RFC3339), "lastTransitionTime": baseTime.Add(5 * time.Minute).Format(time.RFC3339), }, @@ -778,17 +1009,17 @@ func TestGroupByAnnotation(t *testing.T) { { Name: "test-value-1", Count: 10, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, { Name: "test-value-2", Count: 5, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, { Name: "test-value-3", Count: 3, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, }) return nil @@ -907,7 +1138,7 @@ func TestGroupByAnnotationWithFiltersAndAggregates(t *testing.T) { { Name: "4", Count: 2, - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "submitted": baseTime.Add(20 * 
time.Minute).Format(time.RFC3339), "lastTransitionTime": baseTime.Add(50 * time.Minute).Format(time.RFC3339), }, @@ -915,7 +1146,7 @@ func TestGroupByAnnotationWithFiltersAndAggregates(t *testing.T) { { Name: "2", Count: 5, - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "submitted": baseTime.Add(1 * time.Minute).Format(time.RFC3339), "lastTransitionTime": baseTime.Add(10 * time.Minute).Format(time.RFC3339), }, @@ -923,7 +1154,7 @@ func TestGroupByAnnotationWithFiltersAndAggregates(t *testing.T) { { Name: "3", Count: 5, - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "submitted": baseTime.Add(3 * time.Minute).Format(time.RFC3339), "lastTransitionTime": baseTime.Add(5 * time.Minute).Format(time.RFC3339), }, @@ -931,7 +1162,7 @@ func TestGroupByAnnotationWithFiltersAndAggregates(t *testing.T) { { Name: "1", Count: 5, - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "submitted": baseTime.Format(time.RFC3339), "lastTransitionTime": baseTime.Format(time.RFC3339), }, @@ -960,7 +1191,7 @@ func TestGroupJobsSkip(t *testing.T) { return &model.JobGroup{ Name: fmt.Sprintf("queue-%d", i), Count: int64(i), - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, } } @@ -1160,12 +1391,20 @@ func getCreateJobsFn(state lookout.JobState) createJobsFn { switch state { case lookout.JobQueued: return makeQueued + case lookout.JobLeased: + return makeLeased case lookout.JobPending: return makePending case lookout.JobRunning: return makeRunning + case lookout.JobSucceeded: + return makeSucceeded case lookout.JobFailed: return makeFailed + case lookout.JobCancelled: + return makeCancelled + case lookout.JobPreempted: + return makePreempted default: return makeQueued } @@ -1186,6 +1425,23 @@ func makeQueued(opts *createJobsOpts, converter *instructions.InstructionConvert Build() } +func makeLeased(opts *createJobsOpts, converter *instructions.InstructionConverter, store *lookoutdb.LookoutDb) { + tSubmit := baseTime + if opts.submittedTime != nil { + tSubmit = *opts.submittedTime + } + lastTransitionTime := baseTime + if opts.lastTransitionTime != nil { + lastTransitionTime = *opts.lastTransitionTime + } + NewJobSimulator(converter, store). + Submit(opts.queue, opts.jobSet, owner, tSubmit, &JobOptions{ + Annotations: opts.annotations, + }). + Lease(uuid.NewString(), lastTransitionTime). + Build() +} + func makePending(opts *createJobsOpts, converter *instructions.InstructionConverter, store *lookoutdb.LookoutDb) { tSubmit := baseTime if opts.submittedTime != nil { @@ -1222,6 +1478,27 @@ func makeRunning(opts *createJobsOpts, converter *instructions.InstructionConver Build() } +func makeSucceeded(opts *createJobsOpts, converter *instructions.InstructionConverter, store *lookoutdb.LookoutDb) { + tSubmit := baseTime + if opts.submittedTime != nil { + tSubmit = *opts.submittedTime + } + lastTransitionTime := baseTime + if opts.lastTransitionTime != nil { + lastTransitionTime = *opts.lastTransitionTime + } + runId := uuid.NewString() + NewJobSimulator(converter, store). + Submit(opts.queue, opts.jobSet, owner, tSubmit, &JobOptions{ + Annotations: opts.annotations, + }). + Pending(runId, cluster, lastTransitionTime.Add(-2*time.Minute)). + Running(runId, cluster, lastTransitionTime.Add(-1*time.Minute)). + RunSucceeded(runId, lastTransitionTime). + Succeeded(lastTransitionTime). 
+ Build() +} + func makeFailed(opts *createJobsOpts, converter *instructions.InstructionConverter, store *lookoutdb.LookoutDb) { tSubmit := baseTime if opts.submittedTime != nil { @@ -1242,3 +1519,40 @@ func makeFailed(opts *createJobsOpts, converter *instructions.InstructionConvert Failed(node, 1, "error", lastTransitionTime). Build() } + +func makeCancelled(opts *createJobsOpts, converter *instructions.InstructionConverter, store *lookoutdb.LookoutDb) { + tSubmit := baseTime + if opts.submittedTime != nil { + tSubmit = *opts.submittedTime + } + lastTransitionTime := baseTime + if opts.lastTransitionTime != nil { + lastTransitionTime = *opts.lastTransitionTime + } + NewJobSimulator(converter, store). + Submit(opts.queue, opts.jobSet, owner, tSubmit, &JobOptions{ + Annotations: opts.annotations, + }). + Cancelled(lastTransitionTime). + Build() +} + +func makePreempted(opts *createJobsOpts, converter *instructions.InstructionConverter, store *lookoutdb.LookoutDb) { + tSubmit := baseTime + if opts.submittedTime != nil { + tSubmit = *opts.submittedTime + } + lastTransitionTime := baseTime + if opts.lastTransitionTime != nil { + lastTransitionTime = *opts.lastTransitionTime + } + runId := uuid.NewString() + NewJobSimulator(converter, store). + Submit(opts.queue, opts.jobSet, owner, tSubmit, &JobOptions{ + Annotations: opts.annotations, + }). + Pending(runId, cluster, lastTransitionTime.Add(-2*time.Minute)). + Running(runId, cluster, lastTransitionTime.Add(-1*time.Minute)). + Preempted(lastTransitionTime). + Build() +} diff --git a/internal/lookoutv2/repository/common.go b/internal/lookoutv2/repository/querybuilder.go similarity index 95% rename from internal/lookoutv2/repository/common.go rename to internal/lookoutv2/repository/querybuilder.go index 33e1725db02..c0999dbd5dd 100644 --- a/internal/lookoutv2/repository/common.go +++ b/internal/lookoutv2/repository/querybuilder.go @@ -58,14 +58,6 @@ type queryOrder struct { direction string } -// Get aggregation expression for column, e.g. 
MAX(j.submitted) -type aggregatorFn func(column *queryColumn) string - -type queryAggregator struct { - column *queryColumn - aggregator aggregatorFn -} - func NewQueryBuilder(lookoutTables *LookoutTables) *QueryBuilder { return &QueryBuilder{ lookoutTables: lookoutTables, @@ -368,11 +360,14 @@ func (qb *QueryBuilder) GroupBy( if err != nil { return nil, err } - queryAggregators, err := qb.getQueryAggregators(aggregates, queryTables) + queryAggregators, err := qb.getQueryAggregators(aggregates, normalFilters, queryTables) + if err != nil { + return nil, err + } + selectListSql, err := qb.getAggregatesSql(queryAggregators) if err != nil { return nil, err } - selectListSql := qb.getAggregatesSql(queryAggregators) orderSql, err := qb.groupByOrderSql(order) if err != nil { return nil, err @@ -912,9 +907,9 @@ func (qb *QueryBuilder) highestPrecedenceTableForColumn(col string, queryTables return selectedTable, nil } -func (qb *QueryBuilder) getQueryAggregators(aggregates []string, queryTables map[string]bool) ([]*queryAggregator, error) { - queryAggregators := make([]*queryAggregator, len(aggregates)) - for i, aggregate := range aggregates { +func (qb *QueryBuilder) getQueryAggregators(aggregates []string, filters []*model.Filter, queryTables map[string]bool) ([]QueryAggregator, error) { + var queryAggregators []QueryAggregator + for _, aggregate := range aggregates { col, err := qb.lookoutTables.ColumnFromField(aggregate) if err != nil { return nil, err @@ -927,25 +922,25 @@ func (qb *QueryBuilder) getQueryAggregators(aggregates []string, queryTables map if err != nil { return nil, err } - fn, err := getAggregatorFn(aggregateType) + newQueryAggregators, err := GetAggregatorsForColumn(qc, aggregateType, filters) if err != nil { return nil, err } - queryAggregators[i] = &queryAggregator{ - column: qc, - aggregator: fn, - } + queryAggregators = append(queryAggregators, newQueryAggregators...) 
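// Editor's sketch (illustrative, not part of this change): for the "state" aggregate,
// GetAggregatorsForColumn is expected to return one aggregator per job state (or, when a
// state filter is present, one per filtered state), each emitting a SUM(CASE ...) column
// named state_<STATE>, matching the SQL asserted in TestQueryBuilder_GroupByStateAggregates
// below. The type and field names here are hypothetical; only the AggregateSql signature,
// the queryColumn fields, and the SQL shape are taken from this diff.
type stateCountAggregator struct {
	column       *queryColumn // e.g. {name: "state", abbrev: "j"}
	stateOrdinal int          // database ordinal of the state, e.g. 1 for QUEUED
	stateName    string       // suffix used for the column alias, e.g. "QUEUED"
}

func (a *stateCountAggregator) AggregateSql() (string, error) {
	// Produces e.g. SUM(CASE WHEN j.state = 1 THEN 1 ELSE 0 END) AS state_QUEUED
	return fmt.Sprintf(
		"SUM(CASE WHEN %s.%s = %d THEN 1 ELSE 0 END) AS state_%s",
		a.column.abbrev, a.column.name, a.stateOrdinal, a.stateName,
	), nil
}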
} return queryAggregators, nil } -func (qb *QueryBuilder) getAggregatesSql(aggregators []*queryAggregator) string { +func (qb *QueryBuilder) getAggregatesSql(aggregators []QueryAggregator) (string, error) { selectList := []string{"COUNT(*) AS count"} for _, agg := range aggregators { - sql := fmt.Sprintf("%s AS %s", agg.aggregator(agg.column), agg.column.name) + sql, err := agg.AggregateSql() + if err != nil { + return "", err + } selectList = append(selectList, sql) } - return strings.Join(selectList, ", ") + return strings.Join(selectList, ", "), nil } func (qb *QueryBuilder) groupByOrderSql(order *model.Order) (string, error) { @@ -962,23 +957,6 @@ func (qb *QueryBuilder) groupByOrderSql(order *model.Order) (string, error) { return fmt.Sprintf("ORDER BY %s %s", col, order.Direction), nil } -func getAggregatorFn(aggregateType AggregateType) (aggregatorFn, error) { - switch aggregateType { - case Max: - return func(col *queryColumn) string { - return fmt.Sprintf("MAX(%s.%s)", col.abbrev, col.name) - }, nil - case Average: - return func(col *queryColumn) string { - return fmt.Sprintf("AVG(%s.%s)", col.abbrev, col.name) - }, nil - case Unknown: - return nil, errors.New("unknown aggregate type") - default: - return nil, errors.Errorf("cannot determine aggregate type: %v", aggregateType) - } -} - func (qb *QueryBuilder) getQueryColumn(col string, queryTables map[string]bool) (*queryColumn, error) { table, err := qb.highestPrecedenceTableForColumn(col, queryTables) if err != nil { diff --git a/internal/lookoutv2/repository/common_test.go b/internal/lookoutv2/repository/querybuilder_test.go similarity index 88% rename from internal/lookoutv2/repository/common_test.go rename to internal/lookoutv2/repository/querybuilder_test.go index 3fe2dd708c5..aa15d3b82c0 100644 --- a/internal/lookoutv2/repository/common_test.go +++ b/internal/lookoutv2/repository/querybuilder_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" + "github.com/armadaproject/armada/internal/common/database/lookout" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/lookoutv2/model" ) @@ -446,6 +447,64 @@ func TestQueryBuilder_GroupByMultipleAggregates(t *testing.T) { assert.Equal(t, []interface{}{"test\\queue", "1234", "abcd", "test\\queue", "5678", "efgh%", "test\\queue", "anon\\\\one%"}, query.Args) } +func TestQueryBuilder_GroupByStateAggregates(t *testing.T) { + stateFilter := &model.Filter{ + Field: "state", + Match: model.MatchAnyOf, + Value: []string{ + string(lookout.JobQueued), + string(lookout.JobLeased), + string(lookout.JobPending), + string(lookout.JobRunning), + }, + } + query, err := NewQueryBuilder(NewTables()).GroupBy( + append(testFilters, stateFilter), + &model.Order{ + Direction: "DESC", + Field: "lastTransitionTime", + }, + &model.GroupedField{ + Field: "jobSet", + }, + []string{ + "lastTransitionTime", + "submitted", + "state", + }, + 20, + 100, + ) + assert.NoError(t, err) + assert.Equal(t, splitByWhitespace(` + SELECT j.jobset, + COUNT(*) AS count, + AVG(j.last_transition_time_seconds) AS last_transition_time_seconds, + MAX(j.submitted) AS submitted, + SUM(CASE WHEN j.state = 1 THEN 1 ELSE 0 END) AS state_QUEUED, + SUM(CASE WHEN j.state = 8 THEN 1 ELSE 0 END) AS state_LEASED, + SUM(CASE WHEN j.state = 2 THEN 1 ELSE 0 END) AS state_PENDING, + SUM(CASE WHEN j.state = 3 THEN 1 ELSE 0 END) AS state_RUNNING + FROM job AS j + INNER JOIN ( + SELECT job_id + FROM user_annotation_lookup + WHERE queue = $1 AND key = $2 AND value = $3 + ) AS ual0 ON 
j.job_id = ual0.job_id + INNER JOIN ( + SELECT job_id + FROM user_annotation_lookup + WHERE queue = $4 AND key = $5 AND value LIKE $6 + ) AS ual1 ON j.job_id = ual1.job_id + WHERE j.queue = $7 AND j.owner LIKE $8 AND j.state IN ($9, $10, $11, $12) + GROUP BY j.jobset + ORDER BY last_transition_time_seconds DESC + LIMIT 100 OFFSET 20 + `), + splitByWhitespace(query.Sql)) + assert.Equal(t, []interface{}{"test\\queue", "1234", "abcd", "test\\queue", "5678", "efgh%", "test\\queue", "anon\\\\one%", 1, 8, 2, 3}, query.Args) +} + func TestQueryBuilder_GroupByAnnotationMultipleAggregates(t *testing.T) { query, err := NewQueryBuilder(NewTables()).GroupBy( testFilters, diff --git a/internal/lookoutv2/repository/tables.go b/internal/lookoutv2/repository/tables.go index 4633620ec31..779f53fc854 100644 --- a/internal/lookoutv2/repository/tables.go +++ b/internal/lookoutv2/repository/tables.go @@ -41,9 +41,10 @@ const ( type AggregateType int const ( - Unknown AggregateType = -1 - Max = 0 - Average = 1 + Unknown AggregateType = -1 + Max = 0 + Average = 1 + StateCounts = 2 ) type LookoutTables struct { @@ -134,6 +135,7 @@ func NewTables() *LookoutTables { groupAggregates: map[string]AggregateType{ submittedCol: Max, lastTransitionTimeCol: Average, + stateCol: StateCounts, }, } } diff --git a/internal/lookoutv2/repository/util.go b/internal/lookoutv2/repository/util.go index 2b7d8820e38..00af6da2b06 100644 --- a/internal/lookoutv2/repository/util.go +++ b/internal/lookoutv2/repository/util.go @@ -166,6 +166,30 @@ func (js *JobSimulator) Submit(queue, jobSet, owner string, timestamp time.Time, return js } +func (js *JobSimulator) Lease(runId string, timestamp time.Time) *JobSimulator { + ts := timestampOrNow(timestamp) + leasedEvent := &armadaevents.EventSequence_Event{ + Created: &ts, + Event: &armadaevents.EventSequence_Event_JobRunLeased{ + JobRunLeased: &armadaevents.JobRunLeased{ + RunId: armadaevents.ProtoUuidFromUuid(uuid.MustParse(runId)), + JobId: js.jobId, + }, + }, + } + js.events = append(js.events, leasedEvent) + + js.job.LastActiveRunId = &runId + js.job.LastTransitionTime = ts + js.job.State = string(lookout.JobLeased) + updateRun(js.job, &runPatch{ + runId: runId, + jobRunState: pointer.String(string(lookout.JobRunLeased)), + pending: &ts, + }) + return js +} + func (js *JobSimulator) Pending(runId string, cluster string, timestamp time.Time) *JobSimulator { ts := timestampOrNow(timestamp) assignedEvent := &armadaevents.EventSequence_Event{ @@ -417,6 +441,31 @@ func (js *JobSimulator) Failed(node string, exitCode int32, message string, time return js } +func (js *JobSimulator) Preempted(timestamp time.Time) *JobSimulator { + ts := timestampOrNow(timestamp) + jobIdProto, err := armadaevents.ProtoUuidFromUlidString(util.NewULID()) + if err != nil { + log.WithError(err).Errorf("Could not convert job ID to UUID: %s", util.NewULID()) + } + + preempted := &armadaevents.EventSequence_Event{ + Created: &ts, + Event: &armadaevents.EventSequence_Event_JobRunPreempted{ + JobRunPreempted: &armadaevents.JobRunPreempted{ + PreemptedJobId: js.jobId, + PreemptiveJobId: jobIdProto, + PreemptedRunId: armadaevents.ProtoUuidFromUuid(uuid.MustParse(uuid.NewString())), + PreemptiveRunId: armadaevents.ProtoUuidFromUuid(uuid.MustParse(uuid.NewString())), + }, + }, + } + js.events = append(js.events, preempted) + + js.job.LastTransitionTime = ts + js.job.State = string(lookout.JobPreempted) + return js +} + func (js *JobSimulator) RunTerminated(runId string, cluster string, node string, message string, timestamp 
time.Time) *JobSimulator { ts := timestampOrNow(timestamp) terminated := &armadaevents.EventSequence_Event{ diff --git a/internal/lookoutv2/swagger.yaml b/internal/lookoutv2/swagger.yaml index 1b81fffe86d..6a18a4dc1e9 100644 --- a/internal/lookoutv2/swagger.yaml +++ b/internal/lookoutv2/swagger.yaml @@ -178,7 +178,7 @@ definitions: aggregates: type: object additionalProperties: - type: string + type: object x-nullable: false filter: type: object diff --git a/internal/scheduler/api.go b/internal/scheduler/api.go index d0ded087588..2e869782731 100644 --- a/internal/scheduler/api.go +++ b/internal/scheduler/api.go @@ -24,16 +24,25 @@ import ( "github.com/armadaproject/armada/pkg/executorapi" ) -// ExecutorApi is a gRPC service that exposes functionality required by the armada executors +// ExecutorApi is the gRPC service executors use to synchronise their state with that of the scheduler. type ExecutorApi struct { - producer pulsar.Producer - jobRepository database.JobRepository - executorRepository database.ExecutorRepository - legacyExecutorRepository database.ExecutorRepository - allowedPriorities []int32 // allowed priority classes - maxJobsPerCall uint // maximum number of jobs that will be leased in a single call - maxPulsarMessageSize uint // maximum sizer of pulsar messages produced - nodeIdLabel string + // Used to send Pulsar messages when, e.g., executors report a job has finished. + producer pulsar.Producer + // Interface to the component storing job information, such as which jobs are leased to a particular executor. + jobRepository database.JobRepository + // Interface to the component storing executor information, such as when we last heard from an executor. + executorRepository database.ExecutorRepository + // Like executorRepository, but used by the legacy scheduler. + legacyExecutorRepository database.ExecutorRepository + // Allowed priority class priorities. + allowedPriorities []int32 + // Max number of job leases sent per call to LeaseJobRuns. + maxJobsPerCall uint + // Max size of Pulsar messages produced. + maxPulsarMessageSizeBytes uint + // See scheduling config. + nodeIdLabel string + // See scheduling config. priorityClassNameOverride *string clock clock.Clock } @@ -46,6 +55,7 @@ func NewExecutorApi(producer pulsar.Producer, maxJobsPerCall uint, nodeIdLabel string, priorityClassNameOverride *string, + maxPulsarMessageSizeBytes uint, ) (*ExecutorApi, error) { if len(allowedPriorities) == 0 { return nil, errors.New("allowedPriorities cannot be empty") @@ -60,60 +70,56 @@ func NewExecutorApi(producer pulsar.Producer, legacyExecutorRepository: legacyExecutorRepository, allowedPriorities: allowedPriorities, maxJobsPerCall: maxJobsPerCall, - maxPulsarMessageSize: 1024 * 1024 * 2, + maxPulsarMessageSizeBytes: maxPulsarMessageSizeBytes, nodeIdLabel: nodeIdLabel, priorityClassNameOverride: priorityClassNameOverride, clock: clock.RealClock{}, }, nil } -// LeaseJobRuns performs the following actions: -// - Stores the request in postgres so that the scheduler can use the job + capacity information in the next scheduling round -// - Determines if any of the job runs in the request are no longer active and should be cancelled -// - Determines if any new job runs should be leased to the executor +// LeaseJobRuns reconciles the state of the executor with that of the scheduler. Specifically it: +// 1. Stores job and capacity information received from the executor to make it available to the scheduler. +// 2.
Notifies the executor if any of its jobs are no longer active, e.g., due to being preempted by the scheduler. +// 3. Transfers any jobs scheduled on this executor cluster that the executor doesn't already have. func (srv *ExecutorApi) LeaseJobRuns(stream executorapi.ExecutorApi_LeaseJobRunsServer) error { - ctx := stream.Context() - log := ctxlogrus.Extract(ctx) // Receive once to get info necessary to get jobs to lease. req, err := stream.Recv() if err != nil { return errors.WithStack(err) } - log.Infof("Handling lease request for executor %s", req.ExecutorId) + ctx := stream.Context() + log := ctxlogrus.Extract(ctx) + log = log.WithField("executor", req.ExecutorId) - // store the executor state for use by the scheduler - executorState := srv.createExecutorState(ctx, req) - if err = srv.executorRepository.StoreExecutor(stream.Context(), executorState); err != nil { + executor := srv.executorFromLeaseRequest(ctx, req) + if err := srv.executorRepository.StoreExecutor(ctx, executor); err != nil { return err } - - // store the executor state for the legacy executor to use - if err = srv.legacyExecutorRepository.StoreExecutor(stream.Context(), executorState); err != nil { + if err = srv.legacyExecutorRepository.StoreExecutor(ctx, executor); err != nil { return err } - requestRuns, err := extractRunIds(req) + requestRuns, err := runIdsFromLeaseRequest(req) if err != nil { return err } - log.Debugf("Executor is currently aware of %d job runs", len(requestRuns)) - - runsToCancel, err := srv.jobRepository.FindInactiveRuns(stream.Context(), requestRuns) + runsToCancel, err := srv.jobRepository.FindInactiveRuns(ctx, requestRuns) if err != nil { return err } - log.Debugf("Detected %d runs that need cancelling", len(runsToCancel)) - - // Fetch new leases from the db - leases, err := srv.jobRepository.FetchJobRunLeases(stream.Context(), req.ExecutorId, srv.maxJobsPerCall, requestRuns) + newRuns, err := srv.jobRepository.FetchJobRunLeases(ctx, req.ExecutorId, srv.maxJobsPerCall, requestRuns) if err != nil { return err } + log.Infof( + "executor currently has %d job runs; sending %d cancellations and %d new runs", + len(requestRuns), len(runsToCancel), len(newRuns), + ) - // if necessary send a list of runs to cancel + // Send any runs that should be cancelled. if len(runsToCancel) > 0 { - err = stream.Send(&executorapi.LeaseStreamMessage{ + if err := stream.Send(&executorapi.LeaseStreamMessage{ Event: &executorapi.LeaseStreamMessage_CancelRuns{ CancelRuns: &executorapi.CancelRuns{ JobRunIdsToCancel: util.Map(runsToCancel, func(x uuid.UUID) *armadaevents.Uuid { @@ -121,25 +127,22 @@ func (srv *ExecutorApi) LeaseJobRuns(stream executorapi.ExecutorApi_LeaseJobRuns }), }, }, - }) - - if err != nil { + }); err != nil { return errors.WithStack(err) } } - // Now send any leases + // Send any scheduled jobs the executor doesn't already have.
decompressor := compress.NewZlibDecompressor() - for _, lease := range leases { + for _, lease := range newRuns { submitMsg := &armadaevents.SubmitJob{} - err = decompressAndMarshall(lease.SubmitMessage, decompressor, submitMsg) - if err != nil { + if err := unmarshalFromCompressedBytes(lease.SubmitMessage, decompressor, submitMsg); err != nil { return err } if srv.priorityClassNameOverride != nil { srv.setPriorityClassName(submitMsg, *srv.priorityClassNameOverride) } - srv.addNodeSelector(submitMsg, lease.Node) + srv.addNodeIdSelector(submitMsg, lease.Node) var groups []string if len(lease.Groups) > 0 { @@ -148,7 +151,7 @@ func (srv *ExecutorApi) LeaseJobRuns(stream executorapi.ExecutorApi_LeaseJobRuns return err } } - err = stream.Send(&executorapi.LeaseStreamMessage{ + err := stream.Send(&executorapi.LeaseStreamMessage{ Event: &executorapi.LeaseStreamMessage_Lease{ Lease: &executorapi.JobRunLease{ JobRunId: armadaevents.ProtoUuidFromUuid(lease.RunID), @@ -189,11 +192,10 @@ func (srv *ExecutorApi) setPriorityClassName(job *armadaevents.SubmitJob, priori } } -func (srv *ExecutorApi) addNodeSelector(job *armadaevents.SubmitJob, nodeId string) { +func (srv *ExecutorApi) addNodeIdSelector(job *armadaevents.SubmitJob, nodeId string) { if job == nil || nodeId == "" { return } - if job.MainObject != nil { switch typed := job.MainObject.Object.(type) { case *armadaevents.KubernetesMainObject_PodSpec: @@ -207,9 +209,10 @@ func addNodeSelector(podSpec *armadaevents.PodSpecWithAvoidList, key string, val return } if podSpec.PodSpec.NodeSelector == nil { - podSpec.PodSpec.NodeSelector = make(map[string]string, 1) + podSpec.PodSpec.NodeSelector = map[string]string{key: value} + } else { + podSpec.PodSpec.NodeSelector[key] = value } - podSpec.PodSpec.NodeSelector[key] = value } func setPriorityClassName(podSpec *armadaevents.PodSpecWithAvoidList, priorityClassName string) { @@ -219,19 +222,19 @@ func setPriorityClassName(podSpec *armadaevents.PodSpecWithAvoidList, priorityCl podSpec.PodSpec.PriorityClassName = priorityClassName } -// ReportEvents publishes all events to pulsar. The events are compacted for more efficient publishing +// ReportEvents publishes all events to Pulsar. The events are compacted for more efficient publishing. func (srv *ExecutorApi) ReportEvents(ctx context.Context, list *executorapi.EventList) (*types.Empty, error) { - err := pulsarutils.CompactAndPublishSequences(ctx, list.Events, srv.producer, srv.maxPulsarMessageSize, schedulers.Pulsar) + err := pulsarutils.CompactAndPublishSequences(ctx, list.Events, srv.producer, srv.maxPulsarMessageSizeBytes, schedulers.Pulsar) return &types.Empty{}, err } -// createExecutorState extracts a schedulerobjects.Executor from the requesrt -func (srv *ExecutorApi) createExecutorState(ctx context.Context, req *executorapi.LeaseRequest) *schedulerobjects.Executor { +// executorFromLeaseRequest extracts a schedulerobjects.Executor from the request. 
+func (srv *ExecutorApi) executorFromLeaseRequest(ctx context.Context, req *executorapi.LeaseRequest) *schedulerobjects.Executor { log := ctxlogrus.Extract(ctx) nodes := make([]*schedulerobjects.Node, 0, len(req.Nodes)) + now := srv.clock.Now().UTC() for _, nodeInfo := range req.Nodes { - node, err := api.NewNodeFromNodeInfo(nodeInfo, req.ExecutorId, srv.allowedPriorities, srv.clock.Now().UTC()) - if err != nil { + if node, err := api.NewNodeFromNodeInfo(nodeInfo, req.ExecutorId, srv.allowedPriorities, now); err != nil { logging.WithStacktrace(log, err).Warnf( "skipping node %s from executor %s", nodeInfo.GetName(), req.GetExecutorId(), ) @@ -244,37 +247,35 @@ func (srv *ExecutorApi) createExecutorState(ctx context.Context, req *executorap Pool: req.Pool, Nodes: nodes, MinimumJobSize: schedulerobjects.ResourceList{Resources: req.MinimumJobSize}, - LastUpdateTime: srv.clock.Now().UTC(), - UnassignedJobRuns: util.Map(req.UnassignedJobRunIds, func(x armadaevents.Uuid) string { - return strings.ToLower(armadaevents.UuidFromProtoUuid(&x).String()) + LastUpdateTime: now, + UnassignedJobRuns: util.Map(req.UnassignedJobRunIds, func(jobId armadaevents.Uuid) string { + return strings.ToLower(armadaevents.UuidFromProtoUuid(&jobId).String()) }), } } -// extractRunIds extracts all the job runs contained in the executor request -func extractRunIds(req *executorapi.LeaseRequest) ([]uuid.UUID, error) { - runIds := make([]uuid.UUID, 0) - // add all runids from nodes +// runIdsFromLeaseRequest returns the ids of all runs in a lease request, including any not yet assigned to a node. +func runIdsFromLeaseRequest(req *executorapi.LeaseRequest) ([]uuid.UUID, error) { + runIds := make([]uuid.UUID, 0, 256) for _, node := range req.Nodes { for runIdStr := range node.RunIdsByState { - runId, err := uuid.Parse(runIdStr) - if err != nil { + if runId, err := uuid.Parse(runIdStr); err != nil { return nil, errors.WithStack(err) + } else { + runIds = append(runIds, runId) } - runIds = append(runIds, runId) } } - // add all unassigned runids for _, runId := range req.UnassignedJobRunIds { runIds = append(runIds, armadaevents.UuidFromProtoUuid(&runId)) } return runIds, nil } -func decompressAndMarshall(b []byte, decompressor compress.Decompressor, msg proto.Message) error { - decompressed, err := decompressor.Decompress(b) +func unmarshalFromCompressedBytes(bytes []byte, decompressor compress.Decompressor, msg proto.Message) error { + decompressedBytes, err := decompressor.Decompress(bytes) if err != nil { return err } - return proto.Unmarshal(decompressed, msg) + return proto.Unmarshal(decompressedBytes, msg) } diff --git a/internal/scheduler/api_test.go b/internal/scheduler/api_test.go index e0e30371755..5587c8cfb96 100644 --- a/internal/scheduler/api_test.go +++ b/internal/scheduler/api_test.go @@ -171,7 +171,7 @@ func TestExecutorApi_LeaseJobRuns(t *testing.T) { mockLegacyExecutorRepository := schedulermocks.NewMockExecutorRepository(ctrl) mockStream := schedulermocks.NewMockExecutorApi_LeaseJobRunsServer(ctrl) - runIds, err := extractRunIds(tc.request) + runIds, err := runIdsFromLeaseRequest(tc.request) require.NoError(t, err) // set up mocks @@ -204,6 +204,7 @@ func TestExecutorApi_LeaseJobRuns(t *testing.T) { maxJobsPerCall, "kubernetes.io/hostname", nil, + 4*1024*1024, ) require.NoError(t, err) server.clock = testClock @@ -331,6 +332,7 @@ func TestExecutorApi_Publish(t *testing.T) { 100, "kubernetes.io/hostname", nil, + 4*1024*1024, ) require.NoError(t, err) diff --git a/internal/scheduler/common.go 
b/internal/scheduler/common.go index f8a29114936..2b8fdbe2abf 100644 --- a/internal/scheduler/common.go +++ b/internal/scheduler/common.go @@ -2,7 +2,6 @@ package scheduler import ( "fmt" - "math" "strconv" "time" @@ -84,11 +83,7 @@ func JobsSummary(jobs []interfaces.LegacySchedulerJob) string { func(jobs []interfaces.LegacySchedulerJob) schedulerobjects.ResourceList { rv := schedulerobjects.NewResourceListWithDefaultSize() for _, job := range jobs { - req := PodRequirementFromLegacySchedulerJob(job, nil) - if req == nil { - continue - } - rv.AddV1ResourceList(req.ResourceRequirements.Requests) + rv.AddV1ResourceList(job.GetResourceRequirements().Requests) } return rv }, @@ -141,29 +136,14 @@ func isEvictedJob(job interfaces.LegacySchedulerJob) bool { return job.GetAnnotations()[schedulerconfig.IsEvictedAnnotation] == "true" } -func targetNodeIdFromLegacySchedulerJob(job interfaces.LegacySchedulerJob) (string, bool) { - req := PodRequirementFromLegacySchedulerJob(job, nil) - if req == nil { - return "", false - } - nodeId, ok := req.NodeSelector[schedulerconfig.NodeIdLabel] +func targetNodeIdFromNodeSelector(nodeSelector map[string]string) (string, bool) { + nodeId, ok := nodeSelector[schedulerconfig.NodeIdLabel] return nodeId, ok } // GangIdAndCardinalityFromLegacySchedulerJob returns a tuple (gangId, gangCardinality, isGangJob, error). -func GangIdAndCardinalityFromLegacySchedulerJob(job interfaces.LegacySchedulerJob, priorityClasses map[string]configuration.PriorityClass) (string, int, bool, error) { - reqs := job.GetRequirements(priorityClasses) - if reqs == nil { - return "", 0, false, nil - } - if len(reqs.ObjectRequirements) != 1 { - return "", 0, false, errors.Errorf("expected exactly one object requirement in %v", reqs) - } - podReqs := reqs.ObjectRequirements[0].GetPodRequirements() - if podReqs == nil { - return "", 0, false, nil - } - return GangIdAndCardinalityFromAnnotations(podReqs.Annotations) +func GangIdAndCardinalityFromLegacySchedulerJob(job interfaces.LegacySchedulerJob) (string, int, bool, error) { + return GangIdAndCardinalityFromAnnotations(job.GetAnnotations()) } // GangIdAndCardinalityFromAnnotations returns a tuple (gangId, gangCardinality, isGangJob, error). @@ -189,17 +169,6 @@ func GangIdAndCardinalityFromAnnotations(annotations map[string]string) (string, return gangId, gangCardinality, true, nil } -// ResourceListAsWeightedMillis returns the linear combination of the milli values in rl with given weights. -// This function overflows for values that exceed MaxInt64. E.g., 1Pi is fine but not 10Pi. 
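// Editor's note (illustrative, not part of this change): the free function removed below is
// replaced by the schedulerobjects.ResourceList.AsWeightedMillis method used elsewhere in
// this diff. Both compute the same quantity: for each resource type in the weights map,
// round(milliValue * weight), summed. A minimal usage sketch, assuming the usual
// k8s.io v1.ResourceList and resource.MustParse helpers:
func exampleWeightedMillis() int64 {
	rl := schedulerobjects.NewResourceListWithDefaultSize()
	rl.AddV1ResourceList(v1.ResourceList{
		"cpu":    resource.MustParse("2"),
		"memory": resource.MustParse("8Gi"),
	})
	// Only resource types present in the weights map contribute:
	// 2 CPU = 2000 milli, weighted by 1.0; memory has no weight here, so the result is 2000.
	return rl.AsWeightedMillis(map[string]float64{"cpu": 1.0})
}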
-func ResourceListAsWeightedMillis(weights map[string]float64, rl schedulerobjects.ResourceList) int64 { - var rv int64 - for t, f := range weights { - q := rl.Get(t) - rv += int64(math.Round(float64(q.MilliValue()) * f)) - } - return rv -} - func PodRequirementsFromLegacySchedulerJobs[S ~[]E, E interfaces.LegacySchedulerJob](jobs S, priorityClasses map[string]configuration.PriorityClass) []*schedulerobjects.PodRequirements { rv := make([]*schedulerobjects.PodRequirements, len(jobs)) for i, job := range jobs { @@ -222,20 +191,12 @@ func PodRequirementFromLegacySchedulerJob[E interfaces.LegacySchedulerJob](job E } annotations[schedulerconfig.JobIdAnnotation] = job.GetId() annotations[schedulerconfig.QueueAnnotation] = job.GetQueue() - info := job.GetRequirements(priorityClasses) + info := job.GetJobSchedulingInfo(priorityClasses) req := PodRequirementFromJobSchedulingInfo(info) req.Annotations = annotations return req } -func PodRequirementsFromJobSchedulingInfos(infos []*schedulerobjects.JobSchedulingInfo) []*schedulerobjects.PodRequirements { - rv := make([]*schedulerobjects.PodRequirements, 0, len(infos)) - for _, info := range infos { - rv = append(rv, PodRequirementFromJobSchedulingInfo(info)) - } - return rv -} - func PodRequirementFromJobSchedulingInfo(info *schedulerobjects.JobSchedulingInfo) *schedulerobjects.PodRequirements { for _, oreq := range info.ObjectRequirements { if preq := oreq.GetPodRequirements(); preq != nil { diff --git a/internal/scheduler/common_test.go b/internal/scheduler/common_test.go index e1a87d287c1..c71cd16513b 100644 --- a/internal/scheduler/common_test.go +++ b/internal/scheduler/common_test.go @@ -134,7 +134,7 @@ func TestResourceListAsWeightedMillis(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - assert.Equal(t, tc.expected, ResourceListAsWeightedMillis(tc.weights, tc.rl)) + assert.Equal(t, tc.expected, tc.rl.AsWeightedMillis(tc.weights)) }) } } @@ -151,6 +151,6 @@ func BenchmarkResourceListAsWeightedMillis(b *testing.B) { } b.ResetTimer() for n := 0; n < b.N; n++ { - ResourceListAsWeightedMillis(weights, rl) + rl.AsWeightedMillis(weights) } } diff --git a/internal/scheduler/context/context.go b/internal/scheduler/context/context.go index a34caac890c..85c93624345 100644 --- a/internal/scheduler/context/context.go +++ b/internal/scheduler/context/context.go @@ -10,7 +10,6 @@ import ( "github.com/pkg/errors" "golang.org/x/exp/maps" "golang.org/x/exp/slices" - v1 "k8s.io/api/core/v1" "github.com/armadaproject/armada/internal/armada/configuration" "github.com/armadaproject/armada/internal/common/armadaerrors" @@ -39,8 +38,12 @@ type SchedulingContext struct { ResourceScarcity map[string]float64 // Per-queue scheduling contexts. QueueSchedulingContexts map[string]*QueueSchedulingContext + // Sum of weights across all queues. + WeightSum float64 // Total resources across all clusters available at the start of the scheduling cycle. TotalResources schedulerobjects.ResourceList + // = TotalResources.AsWeightedMillis(ResourceScarcity). + TotalResourcesAsWeightedMillis int64 // Resources assigned across all queues during this scheduling cycle. 
ScheduledResources schedulerobjects.ResourceList ScheduledResourcesByPriorityClass schedulerobjects.QuantityByTAndResourceType[string] @@ -81,6 +84,7 @@ func NewSchedulingContext( ResourceScarcity: resourceScarcity, QueueSchedulingContexts: make(map[string]*QueueSchedulingContext), TotalResources: totalResources.DeepCopy(), + TotalResourcesAsWeightedMillis: totalResources.AsWeightedMillis(resourceScarcity), ScheduledResources: schedulerobjects.NewResourceListWithDefaultSize(), ScheduledResourcesByPriorityClass: make(schedulerobjects.QuantityByTAndResourceType[string]), EvictedResourcesByPriorityClass: make(schedulerobjects.QuantityByTAndResourceType[string]), @@ -107,7 +111,7 @@ func (sctx *SchedulingContext) ClearUnfeasibleSchedulingKeys() { sctx.UnfeasibleSchedulingKeys = make(map[schedulerobjects.SchedulingKey]*JobSchedulingContext) } -func (sctx *SchedulingContext) AddQueueSchedulingContext(queue string, priorityFactor float64, initialAllocatedByPriorityClass schedulerobjects.QuantityByTAndResourceType[string]) error { +func (sctx *SchedulingContext) AddQueueSchedulingContext(queue string, weight float64, initialAllocatedByPriorityClass schedulerobjects.QuantityByTAndResourceType[string]) error { if _, ok := sctx.QueueSchedulingContexts[queue]; ok { return errors.WithStack(&armadaerrors.ErrInvalidArgument{ Name: "queue", @@ -124,12 +128,13 @@ func (sctx *SchedulingContext) AddQueueSchedulingContext(queue string, priorityF for _, rl := range initialAllocatedByPriorityClass { allocated.Add(rl) } + sctx.WeightSum += weight qctx := &QueueSchedulingContext{ SchedulingContext: sctx, Created: time.Now(), ExecutorId: sctx.ExecutorId, Queue: queue, - PriorityFactor: priorityFactor, + Weight: weight, Allocated: allocated, AllocatedByPriorityClass: initialAllocatedByPriorityClass, ScheduledResourcesByPriorityClass: make(schedulerobjects.QuantityByTAndResourceType[string]), @@ -314,8 +319,8 @@ type QueueSchedulingContext struct { ExecutorId string // Queue name. Queue string - // These factors influence the fraction of resources assigned to each queue. - PriorityFactor float64 + // Determines the fair share of this queue relative to other queues. + Weight float64 // Total resources assigned to the queue across all clusters by priority class priority. // Includes jobs scheduled during this invocation of the scheduler. Allocated schedulerobjects.ResourceList @@ -461,13 +466,13 @@ func (qctx *QueueSchedulingContext) AddJobSchedulingContext(jctx *JobSchedulingC func (qctx *QueueSchedulingContext) EvictJob(job interfaces.LegacySchedulerJob) (bool, error) { jobId := job.GetId() - _, rl := priorityAndRequestsFromLegacySchedulerJob(job, qctx.SchedulingContext.PriorityClasses) if _, ok := qctx.UnsuccessfulJobSchedulingContexts[jobId]; ok { return false, errors.Errorf("failed evicting job %s from queue: job already marked unsuccessful", jobId) } if _, ok := qctx.EvictedJobsById[jobId]; ok { return false, errors.Errorf("failed evicting job %s from queue: job already marked evicted", jobId) } + rl := job.GetResourceRequirements().Requests _, scheduledInThisRound := qctx.SuccessfulJobSchedulingContexts[jobId] if scheduledInThisRound { qctx.ScheduledResourcesByPriorityClass.SubV1ResourceList(job.GetPriorityClassName(), rl) @@ -481,19 +486,6 @@ func (qctx *QueueSchedulingContext) EvictJob(job interfaces.LegacySchedulerJob) return scheduledInThisRound, nil } -// TODO: Remove? 
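// Editor's worked example (illustrative, not part of this change): the FractionOfFairShare
// helpers added to QueueSchedulingContext below combine the new WeightSum and
// TotalResourcesAsWeightedMillis fields. Assuming three queues of equal weight
// (WeightSum = 3) and a pool of 32 CPU with cpu scarcity 1.0
// (TotalResourcesAsWeightedMillis = 32000), a queue allocated 9 CPU has:
func exampleFractionOfFairShare() float64 {
	fairShare := 1.0 / 3.0 // qctx.Weight / sctx.WeightSum
	// Allocated (9000 weighted millis) over total (32000), relative to the fair share:
	return (9000.0 / 32000.0) / fairShare // ≈ 0.84
}

// With the eviction filter added in preempting_queue_scheduler.go, a queue whose fraction is
// at or below Preemption.ProtectedFractionOfFairShare (e.g. 0.5 in the new scheduler tests)
// is protected from node eviction; at ≈0.84 this queue's preemptible jobs remain evictable.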
-func priorityAndRequestsFromLegacySchedulerJob(job interfaces.LegacySchedulerJob, priorityClasses map[string]configuration.PriorityClass) (int32, v1.ResourceList) { - req := job.GetRequirements(priorityClasses) - for _, r := range req.ObjectRequirements { - podReqs := r.GetPodRequirements() - if podReqs == nil { - continue - } - return podReqs.Priority, podReqs.ResourceRequirements.Requests - } - return 0, nil -} - // ClearJobSpecs zeroes out job specs to reduce memory usage. func (qctx *QueueSchedulingContext) ClearJobSpecs() { for _, jctx := range qctx.SuccessfulJobSchedulingContexts { @@ -504,6 +496,19 @@ func (qctx *QueueSchedulingContext) ClearJobSpecs() { } } +// FractionOfFairShare returns a number in [0, 1] indicating what fraction of its fair share this queue is allocated. +func (qctx *QueueSchedulingContext) FractionOfFairShare() float64 { + return qctx.FractionOfFairShareWithAllocation(qctx.Allocated) +} + +// FractionOfFairShareWithAllocation returns a number in [0, 1] indicating what +// fraction of its fair share this queue is allocated if the total allocation of this queue is given by allocated. +func (qctx *QueueSchedulingContext) FractionOfFairShareWithAllocation(allocated schedulerobjects.ResourceList) float64 { + fairShare := qctx.Weight / qctx.SchedulingContext.WeightSum + allocatedAsWeightedMillis := allocated.AsWeightedMillis(qctx.SchedulingContext.ResourceScarcity) + return (float64(allocatedAsWeightedMillis) / float64(qctx.SchedulingContext.TotalResourcesAsWeightedMillis)) / fairShare +} + type GangSchedulingContext struct { Created time.Time Queue string diff --git a/internal/scheduler/context/context_test.go b/internal/scheduler/context/context_test.go index 0c8704f7de9..e00a9a5d0cb 100644 --- a/internal/scheduler/context/context_test.go +++ b/internal/scheduler/context/context_test.go @@ -88,6 +88,6 @@ func testSmallCpuJobSchedulingContext(queue, priorityClassName string) *JobSched NumNodes: 1, JobId: job.GetId(), Job: job, - Req: job.GetRequirements(nil).ObjectRequirements[0].GetPodRequirements(), + Req: job.GetJobSchedulingInfo(nil).ObjectRequirements[0].GetPodRequirements(), } } diff --git a/internal/scheduler/interfaces/interfaces.go b/internal/scheduler/interfaces/interfaces.go index 7786fb995a0..409409f3bf0 100644 --- a/internal/scheduler/interfaces/interfaces.go +++ b/internal/scheduler/interfaces/interfaces.go @@ -1,6 +1,8 @@ package interfaces import ( + "time" + v1 "k8s.io/api/core/v1" "github.com/armadaproject/armada/internal/armada/configuration" @@ -12,24 +14,13 @@ type LegacySchedulerJob interface { GetId() string GetQueue() string GetJobSet() string + GetPerQueuePriority() uint32 + GetSubmitTime() time.Time GetAnnotations() map[string]string - GetRequirements(map[string]configuration.PriorityClass) *schedulerobjects.JobSchedulingInfo + GetJobSchedulingInfo(map[string]configuration.PriorityClass) *schedulerobjects.JobSchedulingInfo GetPriorityClassName() string GetNodeSelector() map[string]string GetAffinity() *v1.Affinity GetTolerations() []v1.Toleration GetResourceRequirements() v1.ResourceRequirements } - -func PodRequirementFromLegacySchedulerJob(job LegacySchedulerJob, priorityClasses map[string]configuration.PriorityClass) *schedulerobjects.PodRequirements { - schedulingInfo := job.GetRequirements(priorityClasses) - if schedulingInfo == nil { - return nil - } - for _, objectReq := range schedulingInfo.ObjectRequirements { - if req := objectReq.GetPodRequirements(); req != nil { - return req - } - } - return nil -} diff --git 
a/internal/scheduler/jobdb/job.go b/internal/scheduler/jobdb/job.go index aae2e08be0f..519fce9d495 100644 --- a/internal/scheduler/jobdb/job.go +++ b/internal/scheduler/jobdb/job.go @@ -13,25 +13,25 @@ import ( // Job is the scheduler-internal representation of a job. type Job struct { - // String representation of the job id + // String representation of the job id. id string // Name of the queue this job belongs to. queue string - // Jobset the job belongs to - // We store this as it's needed for sending job event messages + // Jobset the job belongs to. + // We store this as it's needed for sending job event messages. jobset string // Per-queue priority of this job. priority uint32 // Requested per queue priority of this job. - // This is used when syncing the postgres database with the scheduler-internal database + // This is used when syncing the postgres database with the scheduler-internal database. requestedPriority uint32 // Logical timestamp indicating the order in which jobs are submitted. // Jobs with identical Queue and Priority are sorted by this. created int64 // True if the job is currently queued. - // If this is set then the job will not be considered for scheduling + // If this is set then the job will not be considered for scheduling. queued bool - // The current version of the queued state + // The current version of the queued state. queuedVersion int32 // Scheduling requirements of this job. jobSchedulingInfo *schedulerobjects.JobSchedulingInfo @@ -71,6 +71,20 @@ func NewJob( cancelled bool, created int64, ) *Job { + // Initialise the annotation and nodeSelector maps if nil. + // Since those need to be mutated in-place. + if schedulingInfo != nil { + for _, req := range schedulingInfo.ObjectRequirements { + if podReq := req.GetPodRequirements(); podReq != nil { + if podReq.Annotations == nil { + podReq.Annotations = make(map[string]string) + } + if podReq.NodeSelector == nil { + podReq.NodeSelector = make(map[string]string) + } + } + } + } return &Job{ id: jobId, jobset: jobset, @@ -126,6 +140,19 @@ func (job *Job) Priority() uint32 { return job.priority } +// GetPerQueuePriority exists for compatibility with the LegacyJob interface. +func (job *Job) GetPerQueuePriority() uint32 { + return job.priority +} + +// GetSubmitTime exists for compatibility with the LegacyJob interface. +func (job *Job) GetSubmitTime() time.Time { + if job.jobSchedulingInfo == nil { + return time.Time{} + } + return job.jobSchedulingInfo.SubmitTime +} + // RequestedPriority returns the requested priority of the job. func (job *Job) RequestedPriority() uint32 { return job.requestedPriority @@ -161,7 +188,7 @@ func (job *Job) GetAnnotations() map[string]string { // GetRequirements returns the scheduling requirements associated with the job. 
// Needed for compatibility with interfaces.LegacySchedulerJob -func (job *Job) GetRequirements(_ map[string]configuration.PriorityClass) *schedulerobjects.JobSchedulingInfo { +func (job *Job) GetJobSchedulingInfo(_ map[string]configuration.PriorityClass) *schedulerobjects.JobSchedulingInfo { return job.JobSchedulingInfo() } diff --git a/internal/scheduler/jobdb/job_test.go b/internal/scheduler/jobdb/job_test.go index 349cb0d998a..8ab6c17a835 100644 --- a/internal/scheduler/jobdb/job_test.go +++ b/internal/scheduler/jobdb/job_test.go @@ -54,7 +54,7 @@ func TestJob_TestGetter(t *testing.T) { assert.Equal(t, baseJob.queue, baseJob.Queue()) assert.Equal(t, baseJob.queue, baseJob.GetQueue()) assert.Equal(t, baseJob.created, baseJob.Created()) - assert.Equal(t, schedulingInfo, baseJob.GetRequirements(nil)) + assert.Equal(t, schedulingInfo, baseJob.GetJobSchedulingInfo(nil)) assert.Equal(t, schedulingInfo, baseJob.JobSchedulingInfo()) assert.Equal(t, baseJob.GetAnnotations(), map[string]string{ "foo": "bar", diff --git a/internal/scheduler/jobiteration.go b/internal/scheduler/jobiteration.go index d9b30f434cf..202bc38bd45 100644 --- a/internal/scheduler/jobiteration.go +++ b/internal/scheduler/jobiteration.go @@ -88,23 +88,23 @@ func (repo *InMemoryJobRepository) Enqueue(job interfaces.LegacySchedulerJob) { // finally by submit time, with earlier submit times first. func (repo *InMemoryJobRepository) sortQueue(queue string) { slices.SortFunc(repo.jobsByQueue[queue], func(a, b interfaces.LegacySchedulerJob) bool { - infoa := a.GetRequirements(repo.priorityClasses) - infob := b.GetRequirements(repo.priorityClasses) if repo.sortByPriorityClass { - pca := repo.priorityClasses[infoa.PriorityClassName] - pcb := repo.priorityClasses[infob.PriorityClassName] + pca := repo.priorityClasses[a.GetPriorityClassName()] + pcb := repo.priorityClasses[b.GetPriorityClassName()] if pca.Priority > pcb.Priority { return true } else if pca.Priority < pcb.Priority { return false } } - if infoa.GetPriority() < infob.GetPriority() { + pa := a.GetPerQueuePriority() + pb := b.GetPerQueuePriority() + if pa < pb { return true - } else if infoa.GetPriority() > infob.GetPriority() { + } else if pa > pb { return false } - return infoa.GetSubmitTime().Before(infob.GetSubmitTime()) + return a.GetSubmitTime().Before(b.GetSubmitTime()) }) } diff --git a/internal/scheduler/preempting_queue_scheduler.go b/internal/scheduler/preempting_queue_scheduler.go index b41a7a6b361..636d8caf713 100644 --- a/internal/scheduler/preempting_queue_scheduler.go +++ b/internal/scheduler/preempting_queue_scheduler.go @@ -32,6 +32,7 @@ type PreemptingQueueScheduler struct { constraints schedulerconstraints.SchedulingConstraints nodeEvictionProbability float64 nodeOversubscriptionEvictionProbability float64 + protectedFractionOfFairShare float64 jobRepo JobRepository nodeDb *nodedb.NodeDb // Maps job ids to the id of the node the job is associated with. 
@@ -53,6 +54,7 @@ func NewPreemptingQueueScheduler( constraints schedulerconstraints.SchedulingConstraints, nodeEvictionProbability float64, nodeOversubscriptionEvictionProbability float64, + protectedFractionOfFairShare float64, jobRepo JobRepository, nodeDb *nodedb.NodeDb, initialNodeIdByJobId map[string]string, @@ -77,6 +79,7 @@ func NewPreemptingQueueScheduler( constraints: constraints, nodeEvictionProbability: nodeEvictionProbability, nodeOversubscriptionEvictionProbability: nodeOversubscriptionEvictionProbability, + protectedFractionOfFairShare: protectedFractionOfFairShare, jobRepo: jobRepo, nodeDb: nodeDb, nodeIdByJobId: maps.Clone(initialNodeIdByJobId), @@ -99,7 +102,7 @@ func (sch *PreemptingQueueScheduler) SkipUnsuccessfulSchedulingKeyCheck() { func (sch *PreemptingQueueScheduler) Schedule(ctx context.Context) (*SchedulerResult, error) { log := ctxlogrus.Extract(ctx) log = log.WithField("service", "PreemptingQueueScheduler") - if ResourceListAsWeightedMillis(sch.schedulingContext.ResourceScarcity, sch.schedulingContext.TotalResources) == 0 { + if sch.schedulingContext.TotalResources.AsWeightedMillis(sch.schedulingContext.ResourceScarcity) == 0 { // This refers to resources available across all clusters, i.e., // it may include resources not currently considered for scheduling. log.Infof( @@ -108,7 +111,7 @@ func (sch *PreemptingQueueScheduler) Schedule(ctx context.Context) (*SchedulerRe ) return &SchedulerResult{}, nil } - if ResourceListAsWeightedMillis(sch.schedulingContext.ResourceScarcity, sch.nodeDb.TotalResources()) == 0 { + if rl := sch.nodeDb.TotalResources(); rl.AsWeightedMillis(sch.schedulingContext.ResourceScarcity) == 0 { // This refers to the resources currently considered for scheduling. log.Infof( "no resources with non-zero weight available for scheduling in NodeDb: resource scarcity %v, total resources %v", @@ -137,11 +140,31 @@ func (sch *PreemptingQueueScheduler) Schedule(ctx context.Context) (*SchedulerRe ctx, log.WithField("stage", "evict for resource balancing"), ), - NewStochasticEvictor( + NewNodeEvictor( sch.jobRepo, sch.schedulingContext.PriorityClasses, - sch.schedulingContext.DefaultPriorityClass, sch.nodeEvictionProbability, + func(ctx context.Context, job interfaces.LegacySchedulerJob) bool { + if job.GetAnnotations() == nil { + log := ctxlogrus.Extract(ctx) + log.Errorf("can't evict job %s: annotations not initialised", job.GetId()) + return false + } + if job.GetNodeSelector() == nil { + log := ctxlogrus.Extract(ctx) + log.Errorf("can't evict job %s: nodeSelector not initialised", job.GetId()) + return false + } + if qctx, ok := sch.schedulingContext.QueueSchedulingContexts[job.GetQueue()]; ok { + if qctx.FractionOfFairShare() <= sch.protectedFractionOfFairShare { + return false + } + } + if priorityClass, ok := sch.schedulingContext.PriorityClasses[job.GetPriorityClassName()]; ok { + return priorityClass.Preemptible + } + return false + }, nil, ), ) @@ -432,27 +455,24 @@ func (sch *PreemptingQueueScheduler) evictionAssertions(evictedJobsById map[stri } evictedJobIdsByGangId := make(map[string]map[string]bool) for _, job := range evictedJobsById { - if gangId, ok := sch.gangIdByJobId[job.GetId()]; ok { + jobId := job.GetId() + if gangId, ok := sch.gangIdByJobId[jobId]; ok { if m := evictedJobIdsByGangId[gangId]; m != nil { - m[job.GetId()] = true + m[jobId] = true } else { - evictedJobIdsByGangId[gangId] = map[string]bool{job.GetId(): true} + evictedJobIdsByGangId[gangId] = map[string]bool{jobId: true} } } if !isEvictedJob(job) { - return 
errors.Errorf("evicted job %s is not marked as such: job annotations %v", job.GetId(), job.GetAnnotations()) + return errors.Errorf("evicted job %s is not marked as such: job annotations %v", jobId, job.GetAnnotations()) } - if nodeId, ok := targetNodeIdFromLegacySchedulerJob(job); ok { + nodeSelector := job.GetNodeSelector() + if nodeId, ok := targetNodeIdFromNodeSelector(nodeSelector); ok { if _, ok := affectedNodesById[nodeId]; !ok { - return errors.Errorf("node id %s targeted by job %s is not marked as affected", nodeId, job.GetId()) + return errors.Errorf("node id %s targeted by job %s is not marked as affected", nodeId, jobId) } } else { - req := PodRequirementFromLegacySchedulerJob(job, nil) - if req != nil { - return errors.Errorf("evicted job %s is missing target node id selector: job nodeSelector %v", job.GetId(), req.NodeSelector) - } else { - return errors.Errorf("evicted job %s is missing target node id selector: req is nil", job.GetId()) - } + return errors.Errorf("evicted job %s is missing target node id selector: job nodeSelector %v", jobId, nodeSelector) } } for gangId, evictedGangJobIds := range evictedJobIdsByGangId { @@ -552,7 +572,7 @@ func (sch *PreemptingQueueScheduler) updateGangAccounting(preemptedJobs, schedul } } for _, job := range scheduledJobs { - gangId, _, isGangJob, err := GangIdAndCardinalityFromLegacySchedulerJob(job, sch.schedulingContext.PriorityClasses) + gangId, _, isGangJob, err := GangIdAndCardinalityFromLegacySchedulerJob(job) if err != nil { return err } @@ -658,13 +678,11 @@ type EvictorResult struct { NodeIdByJobId map[string]string } -// NewStochasticEvictor returns a new evictor that for each node evicts -// all preemptible jobs from that node with probability perNodeEvictionProbability. -func NewStochasticEvictor( +func NewNodeEvictor( jobRepo JobRepository, priorityClasses map[string]configuration.PriorityClass, - defaultPriorityClass string, perNodeEvictionProbability float64, + jobFilter func(context.Context, interfaces.LegacySchedulerJob) bool, random *rand.Rand, ) *Evictor { if perNodeEvictionProbability <= 0 { @@ -673,44 +691,13 @@ func NewStochasticEvictor( if random == nil { random = rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) } - return NewPreemptibleEvictor( - jobRepo, - priorityClasses, - defaultPriorityClass, - func(_ context.Context, node *schedulerobjects.Node) bool { - return len(node.AllocatedByJobId) > 0 && random.Float64() < perNodeEvictionProbability - }, - ) -} - -// NewPreemptibleEvictor returns a new evictor that evicts all preemptible jobs -// on nodes for which nodeFilter returns true. 
-func NewPreemptibleEvictor( - jobRepo JobRepository, - priorityClasses map[string]configuration.PriorityClass, - defaultPriorityClass string, - nodeFilter func(context.Context, *schedulerobjects.Node) bool, -) *Evictor { return &Evictor{ jobRepo: jobRepo, priorityClasses: priorityClasses, - nodeFilter: nodeFilter, - jobFilter: func(ctx context.Context, job interfaces.LegacySchedulerJob) bool { - if job.GetAnnotations() == nil { - log := ctxlogrus.Extract(ctx) - log.Warnf("can't evict job %s: annotations not initialised", job.GetId()) - return false - } - priorityClassName := job.GetRequirements(priorityClasses).PriorityClassName - priorityClass, ok := priorityClasses[priorityClassName] - if !ok { - priorityClass = priorityClasses[defaultPriorityClass] - } - if priorityClass.Preemptible { - return true - } - return false + nodeFilter: func(_ context.Context, node *schedulerobjects.Node) bool { + return len(node.AllocatedByJobId) > 0 && random.Float64() < perNodeEvictionProbability }, + jobFilter: jobFilter, postEvictFunc: defaultPostEvictFunc, } } @@ -786,7 +773,7 @@ func NewOversubscribedEvictor( log.Warnf("can't evict job %s: annotations not initialised", job.GetId()) return false } - priorityClassName := job.GetRequirements(priorityClasses).PriorityClassName + priorityClassName := job.GetPriorityClassName() priorityClass, ok := priorityClasses[priorityClassName] if !ok { priorityClass = priorityClasses[defaultPriorityClass] @@ -864,22 +851,11 @@ func defaultPostEvictFunc(ctx context.Context, job interfaces.LegacySchedulerJob } // Add node selector ensuring this job is only re-scheduled onto the node it was evicted from. - req := PodRequirementFromLegacySchedulerJob(job, nil) - if req.NodeSelector == nil { + nodeSelector := job.GetNodeSelector() + if nodeSelector == nil { log := ctxlogrus.Extract(ctx) log.Errorf("error evicting job %s: nodeSelector not initialised", job.GetId()) } else { - req.NodeSelector[schedulerconfig.NodeIdLabel] = node.Id + nodeSelector[schedulerconfig.NodeIdLabel] = node.Id } - - // Add a toleration to allow the job to be re-scheduled even if node is unschedulable. - // - // TODO: Because req is created with a new tolerations slice above, this toleration doesn't persist. - // In practice, this isn't an issue now since we don't check static requirements for evicted jobs. - if node.Unschedulable { - req.Tolerations = append(req.Tolerations, nodedb.UnschedulableToleration()) - } - - // We've changed the scheduling requirements and must clear any cached key. 
- req.ClearCachedSchedulingKey() } diff --git a/internal/scheduler/preempting_queue_scheduler_test.go b/internal/scheduler/preempting_queue_scheduler_test.go index b2151e7e9b6..7ab6ae1d3fb 100644 --- a/internal/scheduler/preempting_queue_scheduler_test.go +++ b/internal/scheduler/preempting_queue_scheduler_test.go @@ -1136,6 +1136,122 @@ func TestPreemptingQueueScheduler(t *testing.T) { "B": 1, }, }, + "ProtectedFractionOfFairShare": { + SchedulingConfig: testfixtures.WithProtectedFractionOfFairShareConfig( + 1.0, + testfixtures.TestSchedulingConfig(), + ), + Nodes: testfixtures.N32CpuNodes(1, testfixtures.TestPriorities), + Rounds: []SchedulingRound{ + { + JobsByQueue: map[string][]*jobdb.Job{ + "A": testfixtures.N1CpuJobs("A", testfixtures.PriorityClass0, 10), + }, + ExpectedScheduledIndices: map[string][]int{ + "A": testfixtures.IntRange(0, 9), + }, + }, + { + JobsByQueue: map[string][]*jobdb.Job{ + "B": testfixtures.N1CpuJobs("B", testfixtures.PriorityClass3, 22), + }, + ExpectedScheduledIndices: map[string][]int{ + "B": testfixtures.IntRange(0, 21), + }, + }, + { + JobsByQueue: map[string][]*jobdb.Job{ + "C": testfixtures.N1CpuJobs("C", testfixtures.PriorityClass0, 1), + }, + }, + {}, // Empty round to make sure nothing changes. + }, + PriorityFactorByQueue: map[string]float64{ + "A": 1, + "B": 1, + "C": 1, + }, + }, + "ProtectedFractionOfFairShare at limit": { + SchedulingConfig: testfixtures.WithProtectedFractionOfFairShareConfig( + 0.5, + testfixtures.TestSchedulingConfig(), + ), + Nodes: testfixtures.N32CpuNodes(1, testfixtures.TestPriorities), + Rounds: []SchedulingRound{ + { + JobsByQueue: map[string][]*jobdb.Job{ + "A": testfixtures.N1CpuJobs("A", testfixtures.PriorityClass0, 8), + }, + ExpectedScheduledIndices: map[string][]int{ + "A": testfixtures.IntRange(0, 7), + }, + }, + { + JobsByQueue: map[string][]*jobdb.Job{ + "B": testfixtures.N1CpuJobs("B", testfixtures.PriorityClass3, 24), + }, + ExpectedScheduledIndices: map[string][]int{ + "B": testfixtures.IntRange(0, 23), + }, + }, + { + JobsByQueue: map[string][]*jobdb.Job{ + "C": testfixtures.N1CpuJobs("C", testfixtures.PriorityClass0, 1), + }, + }, + {}, // Empty round to make sure nothing changes. + }, + PriorityFactorByQueue: map[string]float64{ + "A": 0.5, + "B": 1, + "C": 1, + }, + }, + "ProtectedFractionOfFairShare above limit": { + SchedulingConfig: testfixtures.WithProtectedFractionOfFairShareConfig( + 0.5, + testfixtures.TestSchedulingConfig(), + ), + Nodes: testfixtures.N32CpuNodes(1, testfixtures.TestPriorities), + Rounds: []SchedulingRound{ + { + JobsByQueue: map[string][]*jobdb.Job{ + "A": testfixtures.N1CpuJobs("A", testfixtures.PriorityClass0, 9), + }, + ExpectedScheduledIndices: map[string][]int{ + "A": testfixtures.IntRange(0, 8), + }, + }, + { + JobsByQueue: map[string][]*jobdb.Job{ + "B": testfixtures.N1CpuJobs("B", testfixtures.PriorityClass3, 23), + }, + ExpectedScheduledIndices: map[string][]int{ + "B": testfixtures.IntRange(0, 22), + }, + }, + { + JobsByQueue: map[string][]*jobdb.Job{ + "C": testfixtures.N1CpuJobs("C", testfixtures.PriorityClass0, 1), + }, + ExpectedScheduledIndices: map[string][]int{ + "C": testfixtures.IntRange(0, 0), + }, + ExpectedPreemptedIndices: map[string]map[int][]int{ + "A": { + 0: testfixtures.IntRange(8, 8), + }, + }, + }, + {}, // Empty round to make sure nothing changes. 
+ }, + PriorityFactorByQueue: map[string]float64{ + "A": 1, + "B": 1, + "C": 1, + }, + }, } for name, tc := range tests { t.Run(name, func(t *testing.T) { @@ -1221,7 +1337,8 @@ func TestPreemptingQueueScheduler(t *testing.T) { tc.TotalResources, ) for queue, priorityFactor := range tc.PriorityFactorByQueue { - err := sctx.AddQueueSchedulingContext(queue, priorityFactor, allocatedByQueueAndPriorityClass[queue]) + weight := 1 / priorityFactor + err := sctx.AddQueueSchedulingContext(queue, weight, allocatedByQueueAndPriorityClass[queue]) require.NoError(t, err) } constraints := schedulerconstraints.SchedulingConstraintsFromSchedulingConfig( @@ -1235,6 +1352,7 @@ func TestPreemptingQueueScheduler(t *testing.T) { constraints, tc.SchedulingConfig.Preemption.NodeEvictionProbability, tc.SchedulingConfig.Preemption.NodeOversubscriptionEvictionProbability, + tc.SchedulingConfig.Preemption.ProtectedFractionOfFairShare, repo, nodeDb, nodeIdByJobId, @@ -1379,13 +1497,7 @@ func BenchmarkPreemptingQueueScheduler(b *testing.B) { MaxPriorityFactor int }{ "1 node 1 queue 320 jobs": { - SchedulingConfig: testfixtures.WithNodeOversubscriptionEvictionProbabilityConfig( - 0, - testfixtures.WithNodeEvictionProbabilityConfig( - 0.1, - testfixtures.TestSchedulingConfig(), - ), - ), + SchedulingConfig: testfixtures.TestSchedulingConfig(), Nodes: testfixtures.N32CpuNodes(1, testfixtures.TestPriorities), JobFunc: testfixtures.N1CpuJobs, NumQueues: 1, @@ -1393,11 +1505,17 @@ func BenchmarkPreemptingQueueScheduler(b *testing.B) { MinPriorityFactor: 1, MaxPriorityFactor: 1, }, + "1 node 10 queues 320 jobs": { + SchedulingConfig: testfixtures.TestSchedulingConfig(), + Nodes: testfixtures.N32CpuNodes(1, testfixtures.TestPriorities), + JobFunc: testfixtures.N1CpuJobs, + NumQueues: 10, + NumJobsPerQueue: 320, + MinPriorityFactor: 1, + MaxPriorityFactor: 1, + }, "10 nodes 1 queue 3200 jobs": { - SchedulingConfig: testfixtures.WithNodeEvictionProbabilityConfig( - 0.1, - testfixtures.TestSchedulingConfig(), - ), + SchedulingConfig: testfixtures.TestSchedulingConfig(), Nodes: testfixtures.N32CpuNodes(10, testfixtures.TestPriorities), JobFunc: testfixtures.N1CpuJobs, NumQueues: 1, @@ -1406,10 +1524,7 @@ func BenchmarkPreemptingQueueScheduler(b *testing.B) { MaxPriorityFactor: 1, }, "10 nodes 10 queues 3200 jobs": { - SchedulingConfig: testfixtures.WithNodeEvictionProbabilityConfig( - 0.1, - testfixtures.TestSchedulingConfig(), - ), + SchedulingConfig: testfixtures.TestSchedulingConfig(), Nodes: testfixtures.N32CpuNodes(10, testfixtures.TestPriorities), JobFunc: testfixtures.N1CpuJobs, NumQueues: 10, @@ -1418,10 +1533,7 @@ func BenchmarkPreemptingQueueScheduler(b *testing.B) { MaxPriorityFactor: 1, }, "100 nodes 1 queue 32000 jobs": { - SchedulingConfig: testfixtures.WithNodeEvictionProbabilityConfig( - 0.1, - testfixtures.TestSchedulingConfig(), - ), + SchedulingConfig: testfixtures.TestSchedulingConfig(), Nodes: testfixtures.N32CpuNodes(100, testfixtures.TestPriorities), JobFunc: testfixtures.N1CpuJobs, NumQueues: 1, @@ -1429,11 +1541,17 @@ func BenchmarkPreemptingQueueScheduler(b *testing.B) { MinPriorityFactor: 1, MaxPriorityFactor: 1, }, + "100 nodes 10 queues 32000 jobs": { + SchedulingConfig: testfixtures.TestSchedulingConfig(), + Nodes: testfixtures.N32CpuNodes(100, testfixtures.TestPriorities), + JobFunc: testfixtures.N1CpuJobs, + NumQueues: 10, + NumJobsPerQueue: 32000, + MinPriorityFactor: 1, + MaxPriorityFactor: 1, + }, "1000 nodes 1 queue 320000 jobs": { - SchedulingConfig: 
testfixtures.WithNodeEvictionProbabilityConfig( - 0.1, - testfixtures.TestSchedulingConfig(), - ), + SchedulingConfig: testfixtures.TestSchedulingConfig(), Nodes: testfixtures.N32CpuNodes(1000, testfixtures.TestPriorities), JobFunc: testfixtures.N1CpuJobs, NumQueues: 1, @@ -1441,6 +1559,15 @@ func BenchmarkPreemptingQueueScheduler(b *testing.B) { MinPriorityFactor: 1, MaxPriorityFactor: 1, }, + "1000 nodes 10 queues 320000 jobs": { + SchedulingConfig: testfixtures.TestSchedulingConfig(), + Nodes: testfixtures.N32CpuNodes(1000, testfixtures.TestPriorities), + JobFunc: testfixtures.N1CpuJobs, + NumQueues: 1, + NumJobsPerQueue: 32000, + MinPriorityFactor: 1, + MaxPriorityFactor: 1, + }, } for name, tc := range tests { b.Run(name, func(b *testing.B) { @@ -1454,8 +1581,7 @@ func BenchmarkPreemptingQueueScheduler(b *testing.B) { nodeDb, err := CreateNodeDb(tc.Nodes) require.NoError(b, err) - repo := NewInMemoryJobRepository(testfixtures.TestPriorityClasses) - allocatedByQueueAndPriorityClass := make(map[string]schedulerobjects.QuantityByTAndResourceType[string]) + jobRepo := NewInMemoryJobRepository(testfixtures.TestPriorityClasses) jobs := make([]interfaces.LegacySchedulerJob, 0) for _, queueJobs := range jobsByQueue { @@ -1463,7 +1589,7 @@ func BenchmarkPreemptingQueueScheduler(b *testing.B) { jobs = append(jobs, job) } } - repo.EnqueueMany(jobs) + jobRepo.EnqueueMany(jobs) sctx := schedulercontext.NewSchedulingContext( "executor", @@ -1474,7 +1600,8 @@ func BenchmarkPreemptingQueueScheduler(b *testing.B) { nodeDb.TotalResources(), ) for queue, priorityFactor := range priorityFactorByQueue { - err := sctx.AddQueueSchedulingContext(queue, priorityFactor, allocatedByQueueAndPriorityClass[queue]) + weight := 1 / priorityFactor + err := sctx.AddQueueSchedulingContext(queue, weight, make(schedulerobjects.QuantityByTAndResourceType[string])) require.NoError(b, err) } constraints := schedulerconstraints.SchedulingConstraintsFromSchedulingConfig( @@ -1488,7 +1615,8 @@ func BenchmarkPreemptingQueueScheduler(b *testing.B) { constraints, tc.SchedulingConfig.Preemption.NodeEvictionProbability, tc.SchedulingConfig.Preemption.NodeOversubscriptionEvictionProbability, - repo, + tc.SchedulingConfig.Preemption.ProtectedFractionOfFairShare, + jobRepo, nodeDb, nil, nil, @@ -1498,19 +1626,30 @@ func BenchmarkPreemptingQueueScheduler(b *testing.B) { require.NoError(b, err) require.Equal(b, 0, len(result.PreemptedJobs)) - // Create a new job repo without the scheduled jobs. 
- scheduledJobsById := make(map[string]interfaces.LegacySchedulerJob) + scheduledJobs := make(map[string]bool) for _, job := range result.ScheduledJobs { - scheduledJobsById[job.GetId()] = job + scheduledJobs[job.GetId()] = true } - unscheduledJobs := make([]interfaces.LegacySchedulerJob, 0) - for _, job := range jobs { - if _, ok := scheduledJobsById[job.GetId()]; !ok { - unscheduledJobs = append(unscheduledJobs, job) - } + for queue, jobs := range jobRepo.jobsByQueue { + jobRepo.jobsByQueue[queue] = armadaslices.Filter(jobs, func(job interfaces.LegacySchedulerJob) bool { return scheduledJobs[job.GetId()] }) } - repo = NewInMemoryJobRepository(testfixtures.TestPriorityClasses) - repo.EnqueueMany(unscheduledJobs) + + nodesById := make(map[string]*schedulerobjects.Node) + for _, node := range tc.Nodes { + nodesById[node.Id] = node + } + for _, job := range result.ScheduledJobs { + nodeId := result.NodeIdByJobId[job.GetId()] + node := nodesById[nodeId] + podRequirements := PodRequirementFromLegacySchedulerJob(job, tc.SchedulingConfig.Preemption.PriorityClasses) + node, err = nodedb.BindPodToNode(podRequirements, node) + require.NoError(b, err) + nodesById[nodeId] = node + } + nodeDb, err = CreateNodeDb(maps.Values(nodesById)) + require.NoError(b, err) + + allocatedByQueueAndPriorityClass := sctx.AllocatedByQueueAndPriority() b.ResetTimer() for n := 0; n < b.N; n++ { @@ -1523,7 +1662,8 @@ func BenchmarkPreemptingQueueScheduler(b *testing.B) { nodeDb.TotalResources(), ) for queue, priorityFactor := range priorityFactorByQueue { - err := sctx.AddQueueSchedulingContext(queue, priorityFactor, allocatedByQueueAndPriorityClass[queue]) + weight := 1 / priorityFactor + err := sctx.AddQueueSchedulingContext(queue, weight, allocatedByQueueAndPriorityClass[queue]) require.NoError(b, err) } sch := NewPreemptingQueueScheduler( @@ -1531,7 +1671,8 @@ func BenchmarkPreemptingQueueScheduler(b *testing.B) { constraints, tc.SchedulingConfig.Preemption.NodeEvictionProbability, tc.SchedulingConfig.Preemption.NodeOversubscriptionEvictionProbability, - repo, + tc.SchedulingConfig.Preemption.ProtectedFractionOfFairShare, + jobRepo, nodeDb, nil, nil, diff --git a/internal/scheduler/queue_scheduler.go b/internal/scheduler/queue_scheduler.go index 9408015ec3c..55883bb2158 100644 --- a/internal/scheduler/queue_scheduler.go +++ b/internal/scheduler/queue_scheduler.go @@ -3,7 +3,6 @@ package scheduler import ( "container/heap" "context" - "math" "reflect" "time" @@ -63,7 +62,7 @@ func (sch *QueueScheduler) SkipUnsuccessfulSchedulingKeyCheck() { func (sch *QueueScheduler) Schedule(ctx context.Context) (*SchedulerResult, error) { log := ctxlogrus.Extract(ctx) - if ResourceListAsWeightedMillis(sch.schedulingContext.ResourceScarcity, sch.schedulingContext.TotalResources) == 0 { + if sch.schedulingContext.TotalResources.AsWeightedMillis(sch.schedulingContext.ResourceScarcity) == 0 { // This refers to resources available across all clusters, i.e., // it may include resources not currently considered for scheduling. log.Infof( @@ -72,8 +71,8 @@ func (sch *QueueScheduler) Schedule(ctx context.Context) (*SchedulerResult, erro ) return &SchedulerResult{}, nil } - if ResourceListAsWeightedMillis(sch.schedulingContext.ResourceScarcity, sch.gangScheduler.nodeDb.TotalResources()) == 0 { - // This refers to the resources currently considered for schedling. 
+ if rl := sch.gangScheduler.nodeDb.TotalResources(); rl.AsWeightedMillis(sch.schedulingContext.ResourceScarcity) == 0 { + // This refers to the resources currently considered for scheduling. log.Infof( "no resources with non-zero weight available for scheduling in NodeDb: resource scarcity %v, total resources %v", sch.schedulingContext.ResourceScarcity, sch.gangScheduler.nodeDb.TotalResources(), @@ -277,12 +276,6 @@ type CandidateGangIterator struct { SchedulingContext *schedulercontext.SchedulingContext // If true, this iterator only yields gangs where all jobs are evicted. onlyYieldEvicted bool - // For each queue, weight is the inverse of the priority factor. - weightByQueue map[string]float64 - // Sum of all weights. - weightSum float64 - // Total weighted resources. - totalResourcesAsWeightedMillis int64 // Reusable buffer to avoid allocations. buffer schedulerobjects.ResourceList // Priority queue containing per-queue iterators. @@ -294,28 +287,10 @@ func NewCandidateGangIterator( sctx *schedulercontext.SchedulingContext, iteratorsByQueue map[string]*QueuedGangIterator, ) (*CandidateGangIterator, error) { - weightSum := 0.0 - weightByQueue := make(map[string]float64, len(iteratorsByQueue)) - for queue := range iteratorsByQueue { - qctx := sctx.QueueSchedulingContexts[queue] - if qctx == nil { - return nil, errors.Errorf("no scheduling context for queue %s", queue) - } - weight := 1 / math.Max(qctx.PriorityFactor, 1) - weightByQueue[queue] = weight - weightSum += weight - } - totalResourcesAsWeightedMillis := ResourceListAsWeightedMillis(sctx.ResourceScarcity, sctx.TotalResources) - if totalResourcesAsWeightedMillis < 1 { - totalResourcesAsWeightedMillis = 1 - } it := &CandidateGangIterator{ - SchedulingContext: sctx, - weightByQueue: weightByQueue, - weightSum: weightSum, - totalResourcesAsWeightedMillis: totalResourcesAsWeightedMillis, - buffer: schedulerobjects.NewResourceListWithDefaultSize(), - pq: make(QueueCandidateGangIteratorPQ, 0, len(iteratorsByQueue)), + SchedulingContext: sctx, + buffer: schedulerobjects.NewResourceListWithDefaultSize(), + pq: make(QueueCandidateGangIteratorPQ, 0, len(iteratorsByQueue)), } for queue, queueIt := range iteratorsByQueue { if _, err := it.updateAndPushPQItem(it.newPQItem(queue, queueIt)); err != nil { @@ -372,17 +347,11 @@ func (it *CandidateGangIterator) updatePQItem(item *QueueCandidateGangIteratorIt // fractionOfFairShareWithGctx returns the fraction of its fair share this queue would have if the jobs in gctx were scheduled. func (it *CandidateGangIterator) fractionOfFairShareWithGctx(gctx *schedulercontext.GangSchedulingContext) float64 { + qctx := it.SchedulingContext.QueueSchedulingContexts[gctx.Queue] it.buffer.Zero() - it.buffer.Add(it.SchedulingContext.QueueSchedulingContexts[gctx.Queue].Allocated) + it.buffer.Add(qctx.Allocated) it.buffer.Add(gctx.TotalResourceRequests) - queueWeight := it.weightByQueue[gctx.Queue] - if queueWeight == 0 { - return 1 - } else { - fairShare := queueWeight / it.weightSum - used := ResourceListAsWeightedMillis(it.SchedulingContext.ResourceScarcity, it.buffer) - return (float64(used) / float64(it.totalResourcesAsWeightedMillis)) / fairShare - } + return qctx.FractionOfFairShareWithAllocation(it.buffer) } // Clear removes the first item in the iterator. 
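Taken together, the scheduler changes above follow one pattern: per-queue priority factors are turned into weights (weight = 1/priorityFactor), a queue's fair share is its weight divided by the sum of all queue weights, and the new protectedFractionOfFairShare setting shields queues whose fraction of fair share is at or below that threshold from having jobs evicted. The following standalone Go sketch (not part of the diff) illustrates that arithmetic; queueState, weight, and isProtected are hypothetical helpers rather than Armada APIs, and resource usage is reduced to a single weighted-millis figure.

package main

import "fmt"

// queueState is a hypothetical, simplified stand-in for a queue scheduling context.
type queueState struct {
	priorityFactor  float64 // per-queue priority factor from config
	allocatedMillis int64   // weighted resource usage, e.g. CPU millis
}

// weight mirrors the "weight = 1 / priorityFactor" conversion used when adding
// queue scheduling contexts (falling back to 1 for non-positive factors).
func weight(priorityFactor float64) float64 {
	if priorityFactor > 0 {
		return 1 / priorityFactor
	}
	return 1
}

// isProtected reports whether a queue's fraction of its fair share is at or below
// protectedFraction, in which case the evictor's job filter skips its jobs.
func isProtected(q queueState, weightSum float64, totalMillis int64, protectedFraction float64) bool {
	fairShare := weight(q.priorityFactor) / weightSum
	actualShare := float64(q.allocatedMillis) / float64(totalMillis)
	return actualShare/fairShare <= protectedFraction
}

func main() {
	queues := map[string]queueState{
		"A": {priorityFactor: 1, allocatedMillis: 10000}, // well over its fair share
		"B": {priorityFactor: 1, allocatedMillis: 4000},  // at half of its fair share
	}
	var weightSum float64
	for _, q := range queues {
		weightSum += weight(q.priorityFactor)
	}
	const totalMillis = 16000 // e.g. a single 16-core node
	for name, q := range queues {
		fmt.Printf("queue %s protected from eviction: %v\n", name, isProtected(q, weightSum, totalMillis, 0.5))
	}
}

With a protected fraction of 1.0, as in the first ProtectedFractionOfFairShare test case above, queue A stays at or below its full fair share and is protected, so queue C's job remains queued rather than triggering preemption; at 0.5, the "above limit" case shows queue A losing one job to preemption once it exceeds the protected threshold.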
diff --git a/internal/scheduler/queue_scheduler_test.go b/internal/scheduler/queue_scheduler_test.go index 9980c402423..0c0348179b2 100644 --- a/internal/scheduler/queue_scheduler_test.go +++ b/internal/scheduler/queue_scheduler_test.go @@ -453,7 +453,8 @@ func TestQueueScheduler(t *testing.T) { tc.TotalResources, ) for queue, priorityFactor := range tc.PriorityFactorByQueue { - err := sctx.AddQueueSchedulingContext(queue, priorityFactor, tc.InitialAllocatedByQueueAndPriorityClass[queue]) + weight := 1 / priorityFactor + err := sctx.AddQueueSchedulingContext(queue, weight, tc.InitialAllocatedByQueueAndPriorityClass[queue]) require.NoError(t, err) } constraints := schedulerconstraints.SchedulingConstraintsFromSchedulingConfig( @@ -577,7 +578,7 @@ func TestQueueScheduler(t *testing.T) { continue } assert.Equal(t, nodeDb.NumNodes(), pctx.NumNodes) - _, _, isGangJob, err := GangIdAndCardinalityFromLegacySchedulerJob(jctx.Job, nil) + _, _, isGangJob, err := GangIdAndCardinalityFromLegacySchedulerJob(jctx.Job) require.NoError(t, err) if !isGangJob { numExcludedNodes := 0 diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go index 24a44336a99..3a3459506a1 100644 --- a/internal/scheduler/scheduler.go +++ b/internal/scheduler/scheduler.go @@ -209,8 +209,8 @@ func (s *Scheduler) cycle(ctx context.Context, updateAll bool, leaderToken Leade } events = append(events, expirationEvents...) + // Schedule jobs. if s.clock.Now().Sub(s.previousSchedulingRoundEnd) > s.schedulePeriod { - // Schedule jobs. overallSchedulerResult, err := s.schedulingAlgo.Schedule(ctx, txn, s.jobDb) if err != nil { return err @@ -222,8 +222,6 @@ func (s *Scheduler) cycle(ctx context.Context, updateAll bool, leaderToken Leade } events = append(events, resultEvents...) s.previousSchedulingRoundEnd = s.clock.Now() - } else { - log.Infof("skipping scheduling new jobs this cycle as a scheduling round ran less than %s ago", s.schedulePeriod) } // Publish to Pulsar. @@ -264,7 +262,7 @@ func (s *Scheduler) syncState(ctx context.Context) ([]*jobdb.Job, error) { // Try and retrieve the job from the jobDb. If it doesn't exist then create it. job := s.jobDb.GetById(txn, dbJob.JobID) if job == nil { - job, err = s.createSchedulerJob(&dbJob) + job, err = s.schedulerJobFromDatabaseJob(&dbJob) if err != nil { return nil, err } @@ -346,8 +344,6 @@ func (s *Scheduler) createSchedulingInfoWithNodeAntiAffinityForAttemptedRuns(job } } podSchedulingRequirement.Affinity = newAffinity - podSchedulingRequirement.ClearCachedSchedulingKey() - return newSchedulingInfo, nil } @@ -817,8 +813,8 @@ func (s *Scheduler) ensureDbUpToDate(ctx context.Context, pollInterval time.Dura } } -// createSchedulerJob creates a new scheduler job from a database job. -func (s *Scheduler) createSchedulerJob(dbJob *database.Job) (*jobdb.Job, error) { +// schedulerJobFromDatabaseJob creates a new scheduler job from a database job. +func (s *Scheduler) schedulerJobFromDatabaseJob(dbJob *database.Job) (*jobdb.Job, error) { schedulingInfo := &schedulerobjects.JobSchedulingInfo{} err := proto.Unmarshal(dbJob.SchedulingInfo, schedulingInfo) if err != nil { @@ -892,7 +888,7 @@ func updateSchedulerRun(run *jobdb.JobRun, dbRun *database.Run) *jobdb.JobRun { return run } -// updateSchedulerJob updates the scheduler job (in-place) to match the database job +// updateSchedulerJob updates the scheduler job in-place to match the database job. 
func updateSchedulerJob(job *jobdb.Job, dbJob *database.Job) (*jobdb.Job, error) { if dbJob.CancelRequested && !job.CancelRequested() { job = job.WithCancelRequested(true) diff --git a/internal/scheduler/schedulerapp.go b/internal/scheduler/schedulerapp.go index fd1b8ade516..625998ef443 100644 --- a/internal/scheduler/schedulerapp.go +++ b/internal/scheduler/schedulerapp.go @@ -115,7 +115,7 @@ func Run(config schedulerconfig.Configuration) error { defer grpcServer.GracefulStop() lis, err := net.Listen("tcp", fmt.Sprintf(":%d", config.Grpc.Port)) if err != nil { - return errors.WithMessage(err, "error setting up grpc server") + return errors.WithMessage(err, "error setting up gRPC server") } allowedPcs := config.Scheduling.Preemption.AllowedPriorities() executorServer, err := NewExecutorApi( @@ -127,6 +127,7 @@ func Run(config schedulerconfig.Configuration) error { config.Scheduling.MaximumJobsToSchedule, config.Scheduling.Preemption.NodeIdLabel, config.Scheduling.Preemption.PriorityClassNameOverride, + config.Pulsar.MaxAllowedMessageSize, ) if err != nil { return errors.WithMessage(err, "error creating executorApi") diff --git a/internal/scheduler/schedulerobjects/podutils.go b/internal/scheduler/schedulerobjects/podutils.go index 287401d2826..9bd476ca99a 100644 --- a/internal/scheduler/schedulerobjects/podutils.go +++ b/internal/scheduler/schedulerobjects/podutils.go @@ -212,25 +212,6 @@ func (skg *PodRequirementsSerialiser) AppendResourceList(out []byte, resourceLis return out } -// ClearCachedSchedulingKey clears any cached scheduling keys. -// Necessary after changing scheduling requirements to avoid inconsistency. -func (jobSchedulingInfo *JobSchedulingInfo) ClearCachedSchedulingKey() { - if jobSchedulingInfo == nil { - return - } - for _, objReq := range jobSchedulingInfo.ObjectRequirements { - if req := objReq.GetPodRequirements(); req != nil { - req.ClearCachedSchedulingKey() - } - } -} - -// ClearCachedSchedulingKey clears any cached scheduling key. -// Necessary after changing scheduling requirements to avoid inconsistency. -func (req *PodRequirements) ClearCachedSchedulingKey() { - req.CachedSchedulingKey = nil -} - func lessToleration(a, b v1.Toleration) bool { if a.Key < b.Key { return true diff --git a/internal/scheduler/schedulerobjects/resourcelist.go b/internal/scheduler/schedulerobjects/resourcelist.go index 7c98a95eff3..9a3c67eb5e3 100644 --- a/internal/scheduler/schedulerobjects/resourcelist.go +++ b/internal/scheduler/schedulerobjects/resourcelist.go @@ -2,6 +2,7 @@ package schedulerobjects import ( "fmt" + math "math" "strings" v1 "k8s.io/api/core/v1" @@ -303,6 +304,17 @@ func (rl ResourceList) CompactString() string { return sb.String() } +// AsWeightedMillis returns the linear combination of the milli values in rl with given weights. +// This function overflows for values greater than MaxInt64. E.g., 1Pi is fine but not 10Pi. 
+func (rl *ResourceList) AsWeightedMillis(weights map[string]float64) int64 { + var rv int64 + for t, w := range weights { + q := rl.Get(t) + rv += int64(math.Round(float64(q.MilliValue()) * w)) + } + return rv +} + func (rl *ResourceList) initialise() { if rl.Resources == nil { rl.Resources = make(map[string]resource.Quantity) diff --git a/internal/scheduler/scheduling_algo.go b/internal/scheduler/scheduling_algo.go index fc46ecd16bc..cc91f19709c 100644 --- a/internal/scheduler/scheduling_algo.go +++ b/internal/scheduler/scheduling_algo.go @@ -264,7 +264,7 @@ func (l *FairSchedulingAlgo) newFairSchedulingAlgoContext(ctx context.Context, t } jobsByExecutorId[executorId] = append(jobsByExecutorId[executorId], job) nodeIdByJobId[job.Id()] = nodeId - gangId, _, isGangJob, err := GangIdAndCardinalityFromLegacySchedulerJob(job, l.config.Preemption.PriorityClasses) + gangId, _, isGangJob, err := GangIdAndCardinalityFromLegacySchedulerJob(job) if err != nil { return nil, err } @@ -327,7 +327,11 @@ func (l *FairSchedulingAlgo) scheduleOnExecutor( if allocatedByQueueAndPriorityClass := accounting.allocationByPoolAndQueueAndPriorityClass[executor.Pool]; allocatedByQueueAndPriorityClass != nil { allocatedByPriorityClass = allocatedByQueueAndPriorityClass[queue] } - if err := sctx.AddQueueSchedulingContext(queue, priorityFactor, allocatedByPriorityClass); err != nil { + var weight float64 = 1 + if priorityFactor > 0 { + weight = 1 / priorityFactor + } + if err := sctx.AddQueueSchedulingContext(queue, weight, allocatedByPriorityClass); err != nil { return nil, nil, err } } @@ -342,6 +346,7 @@ func (l *FairSchedulingAlgo) scheduleOnExecutor( constraints, l.config.Preemption.NodeEvictionProbability, l.config.Preemption.NodeOversubscriptionEvictionProbability, + l.config.Preemption.ProtectedFractionOfFairShare, &schedulerJobRepositoryAdapter{ txn: txn, db: db, diff --git a/internal/scheduler/testfixtures/testfixtures.go b/internal/scheduler/testfixtures/testfixtures.go index 559c14b3d33..07ccae0f531 100644 --- a/internal/scheduler/testfixtures/testfixtures.go +++ b/internal/scheduler/testfixtures/testfixtures.go @@ -88,7 +88,7 @@ func ContextWithDefaultLogger(ctx context.Context) context.Context { func TestSchedulingConfig() configuration.SchedulingConfig { return configuration.SchedulingConfig{ - ResourceScarcity: map[string]float64{"cpu": 1, "memory": 0}, + ResourceScarcity: map[string]float64{"cpu": 1}, Preemption: configuration.PreemptionConfig{ PriorityClasses: maps.Clone(TestPriorityClasses), DefaultPriorityClass: TestDefaultPriorityClass, @@ -101,8 +101,13 @@ func TestSchedulingConfig() configuration.SchedulingConfig { } } -func WithMaxUnacknowledgedJobsPerExecutor(i uint, config configuration.SchedulingConfig) configuration.SchedulingConfig { - config.MaxUnacknowledgedJobsPerExecutor = i +func WithMaxUnacknowledgedJobsPerExecutorConfig(v uint, config configuration.SchedulingConfig) configuration.SchedulingConfig { + config.MaxUnacknowledgedJobsPerExecutor = v + return config +} + +func WithProtectedFractionOfFairShareConfig(v float64, config configuration.SchedulingConfig) configuration.SchedulingConfig { + config.Preemption.ProtectedFractionOfFairShare = v return config } @@ -266,7 +271,7 @@ func WithRequestsPodReqs(rl schedulerobjects.ResourceList, reqs []*schedulerobje func WithNodeSelectorJobs(selector map[string]string, jobs []*jobdb.Job) []*jobdb.Job { for _, job := range jobs { - for _, req := range job.GetRequirements(nil).GetObjectRequirements() { + for _, req := range 
job.GetJobSchedulingInfo(nil).GetObjectRequirements() { req.GetPodRequirements().NodeSelector = maps.Clone(selector) } } @@ -284,7 +289,7 @@ func WithGangAnnotationsJobs(jobs []*jobdb.Job) []*jobdb.Job { func WithAnnotationsJobs(annotations map[string]string, jobs []*jobdb.Job) []*jobdb.Job { for _, job := range jobs { - for _, req := range job.GetRequirements(nil).GetObjectRequirements() { + for _, req := range job.GetJobSchedulingInfo(nil).GetObjectRequirements() { if req.GetPodRequirements().Annotations == nil { req.GetPodRequirements().Annotations = make(map[string]string) } diff --git a/internal/scheduleringester/dbops.go b/internal/scheduleringester/dbops.go index f4b448dec47..b4300cc868a 100644 --- a/internal/scheduleringester/dbops.go +++ b/internal/scheduleringester/dbops.go @@ -99,7 +99,7 @@ func AppendDbOperation(ops []DbOperation, op DbOperation) []DbOperation { break } } - return discardNilOps(ops) // TODO: Can be made more efficient. + return discardNilOps(ops) } func discardNilOps(ops []DbOperation) []DbOperation { diff --git a/internal/scheduleringester/instructions.go b/internal/scheduleringester/instructions.go index 3ff86553121..fa0e70c09a9 100644 --- a/internal/scheduleringester/instructions.go +++ b/internal/scheduleringester/instructions.go @@ -48,7 +48,7 @@ func NewInstructionConverter( func (c *InstructionConverter) Convert(_ context.Context, sequencesWithIds *ingest.EventSequencesWithIds) *DbOperationsWithMessageIds { operations := make([]DbOperation, 0) for _, es := range sequencesWithIds.EventSequences { - for _, op := range c.convertSequence(es) { + for _, op := range c.dbOperationsFromEventSequence(es) { operations = AppendDbOperation(operations, op) } } @@ -58,14 +58,13 @@ func (c *InstructionConverter) Convert(_ context.Context, sequencesWithIds *inge } } -func (c *InstructionConverter) convertSequence(es *armadaevents.EventSequence) []DbOperation { +func (c *InstructionConverter) dbOperationsFromEventSequence(es *armadaevents.EventSequence) []DbOperation { meta := eventSequenceCommon{ queue: es.Queue, jobset: es.JobSetName, user: es.UserId, groups: es.Groups, } - operations := make([]DbOperation, 0, len(es.Events)) for idx, event := range es.Events { eventTime := time.Now().UTC() @@ -117,7 +116,7 @@ func (c *InstructionConverter) convertSequence(es *armadaevents.EventSequence) [ } if err != nil { c.metrics.RecordPulsarMessageError(metrics.PulsarMessageErrorProcessing) - log.WithError(err).Warnf("Could not convert event at index %d.", idx) + log.WithError(err).Errorf("Could not convert event at index %d.", idx) } else { operations = append(operations, operationsFromEvent...) } @@ -148,7 +147,7 @@ func (c *InstructionConverter) handleSubmitJob(job *armadaevents.SubmitJob, subm // Produce a minimal representation of the job for the scheduler. // To avoid the scheduler needing to load the entire job spec. - schedulingInfo, err := c.schedulingInfoFromSubmitJob(job) + schedulingInfo, err := c.schedulingInfoFromSubmitJob(job, submitTime) if err != nil { return nil, err } @@ -357,13 +356,15 @@ func (c *InstructionConverter) handlePartitionMarker(pm *armadaevents.PartitionM // schedulingInfoFromSubmitJob returns a minimal representation of a job // containing only the info needed by the scheduler. 
-func (c *InstructionConverter) schedulingInfoFromSubmitJob(submitJob *armadaevents.SubmitJob) (*schedulerobjects.JobSchedulingInfo, error) { +func (c *InstructionConverter) schedulingInfoFromSubmitJob(submitJob *armadaevents.SubmitJob, submitTime time.Time) (*schedulerobjects.JobSchedulingInfo, error) { // Component common to all jobs. schedulingInfo := &schedulerobjects.JobSchedulingInfo{ Lifetime: submitJob.Lifetime, AtMostOnce: submitJob.AtMostOnce, Preemptible: submitJob.Preemptible, ConcurrencySafe: submitJob.ConcurrencySafe, + SubmitTime: submitTime, + Priority: submitJob.Priority, Version: 0, } @@ -371,12 +372,16 @@ func (c *InstructionConverter) schedulingInfoFromSubmitJob(submitJob *armadaeven switch object := submitJob.MainObject.Object.(type) { case *armadaevents.KubernetesMainObject_PodSpec: podSpec := object.PodSpec.PodSpec - requirements := &schedulerobjects.ObjectRequirements_PodRequirements{ - PodRequirements: adapters.PodRequirementsFromPodSpec(podSpec, c.priorityClasses), - } + schedulingInfo.PriorityClassName = podSpec.PriorityClassName + podRequirements := adapters.PodRequirementsFromPodSpec(podSpec, c.priorityClasses) + podRequirements.Annotations = submitJob.ObjectMeta.Annotations schedulingInfo.ObjectRequirements = append( schedulingInfo.ObjectRequirements, - &schedulerobjects.ObjectRequirements{Requirements: requirements}, + &schedulerobjects.ObjectRequirements{ + Requirements: &schedulerobjects.ObjectRequirements_PodRequirements{ + PodRequirements: podRequirements, + }, + }, ) default: return nil, errors.Errorf("unsupported object type %T", object) diff --git a/internal/scheduleringester/instructions_test.go b/internal/scheduleringester/instructions_test.go index 45debe5c30d..99ecde22f1a 100644 --- a/internal/scheduleringester/instructions_test.go +++ b/internal/scheduleringester/instructions_test.go @@ -201,7 +201,7 @@ func TestConvertSequence(t *testing.T) { t.Run(name, func(t *testing.T) { converter := InstructionConverter{m, f.PriorityClasses, compressor} es := f.NewEventSequence(tc.events...) - results := converter.convertSequence(es) + results := converter.dbOperationsFromEventSequence(es) assertOperationsEqual(t, tc.expected, results) }) } @@ -272,11 +272,14 @@ func assertErrorMessagesEqual(t *testing.T, expectedBytes []byte, actualBytes [] func getExpectedSubmitMessageSchedulingInfo(t *testing.T) *schedulerobjects.JobSchedulingInfo { expectedSubmitSchedulingInfo := &schedulerobjects.JobSchedulingInfo{ - Lifetime: 0, - AtMostOnce: true, - Preemptible: true, - ConcurrencySafe: true, - Version: 0, + Lifetime: 0, + AtMostOnce: true, + Preemptible: true, + ConcurrencySafe: true, + Version: 0, + PriorityClassName: "test-priority", + Priority: 3, + SubmitTime: f.BaseTime, ObjectRequirements: []*schedulerobjects.ObjectRequirements{ { Requirements: &schedulerobjects.ObjectRequirements_PodRequirements{ diff --git a/internal/scheduleringester/schedulerdb.go b/internal/scheduleringester/schedulerdb.go index ad34ada01c9..9b944836c2e 100644 --- a/internal/scheduleringester/schedulerdb.go +++ b/internal/scheduleringester/schedulerdb.go @@ -42,10 +42,9 @@ func NewSchedulerDb( } } -// Store persists all operations in the database. Note that: -// - this function will retry until it either succeeds or a terminal error is encountered -// - this function will take out a postgres lock to ensure that other ingesters are not writing to the database -// at the same time (for details, see acquireLock()) +// Store persists all operations in the database. 
+// This function retries until it either succeeds or encounters a terminal error. +// This function takes a postgres advisory lock to avoid write conflicts; see acquireLock() for details. func (s *SchedulerDb) Store(ctx context.Context, instructions *DbOperationsWithMessageIds) error { return ingest.WithRetry(func() (bool, error) { err := s.db.BeginTxFunc(ctx, pgx.TxOptions{ @@ -53,36 +52,38 @@ func (s *SchedulerDb) Store(ctx context.Context, instructions *DbOperationsWithM AccessMode: pgx.ReadWrite, DeferrableMode: pgx.Deferrable, }, func(tx pgx.Tx) error { - // First acquire the write lock lockCtx, cancel := context.WithTimeout(ctx, s.lockTimeout) defer cancel() - err := s.acquireLock(lockCtx, tx) - if err != nil { + // The lock is released automatically on transaction rollback/commit. + if err := s.acquireLock(lockCtx, tx); err != nil { return err } - // Now insert the ops for _, dbOp := range instructions.Ops { - err := s.WriteDbOp(ctx, tx, dbOp) - if err != nil { + if err := s.WriteDbOp(ctx, tx, dbOp); err != nil { return err } } - return err + return nil }) return true, err }, s.initialBackOff, s.maxBackOff) } -// acquireLock acquires the armada_scheduleringester_lock, which prevents two ingesters writing to the db at the same -// time. This is necessary because: -// - when rows are inserted into the database they are stamped with a sequence number -// - the scheduler relies on this sequence number increasing to ensure it has fetched all updated rows -// - concurrent transactions will result in sequence numbers being interleaved across transactions. -// - the interleaved sequences may result in the scheduler seeing sequence numbers that do not strictly increase over time. +// acquireLock acquires a postgres advisory lock, thus preventing concurrent writes. +// This is necessary to ensure sequence numbers assigned to each inserted row are monotonically increasing. +// Such a sequence number is assigned to each inserted row by a postgres function. +// +// Hence, if rows are inserted across multiple transactions concurrently, +// sequence numbers may be interleaved between transactions and the slower transaction may insert +// rows with sequence numbers smaller than those already written. +// +// The scheduler relies on these sequence numbers to only fetch new or updated rows in each update cycle.
func (s *SchedulerDb) acquireLock(ctx context.Context, tx pgx.Tx) error { const lockId = 8741339439634283896 - _, err := tx.Exec(ctx, "SELECT pg_advisory_xact_lock($1)", lockId) - return errors.Wrapf(err, "Could not obtain lock") + if _, err := tx.Exec(ctx, "SELECT pg_advisory_xact_lock($1)", lockId); err != nil { + return errors.Wrapf(err, "could not obtain lock") + } + return nil } func (s *SchedulerDb) WriteDbOp(ctx context.Context, tx pgx.Tx, op DbOperation) error { diff --git a/pkg/api/util.go b/pkg/api/util.go index 45e047bfe36..399a3461446 100644 --- a/pkg/api/util.go +++ b/pkg/api/util.go @@ -102,11 +102,22 @@ func JobRunStateFromApiJobState(s JobState) schedulerobjects.JobRunState { return schedulerobjects.JobRunState_UNKNOWN } -func NewNodeTypeFromNodeInfo(nodeInfo *NodeInfo, indexedTaints map[string]interface{}, indexedLabels map[string]interface{}) *schedulerobjects.NodeType { - return schedulerobjects.NewNodeType(nodeInfo.GetTaints(), nodeInfo.GetLabels(), indexedTaints, indexedLabels) +func (job *Job) GetPerQueuePriority() uint32 { + priority := job.Priority + if priority < 0 { + return 0 + } + if priority > math.MaxUint32 { + return math.MaxUint32 + } + return uint32(math.Round(priority)) +} + +func (job *Job) GetSubmitTime() time.Time { + return job.Created } -func (job *Job) GetRequirements(priorityClasses map[string]configuration.PriorityClass) *schedulerobjects.JobSchedulingInfo { +func (job *Job) GetJobSchedulingInfo(priorityClasses map[string]configuration.PriorityClass) *schedulerobjects.JobSchedulingInfo { podSpec := job.GetMainPodSpec() priority, ok := PriorityFromPodSpec(podSpec, priorityClasses) @@ -132,8 +143,8 @@ func (job *Job) GetRequirements(priorityClasses map[string]configuration.Priorit } return &schedulerobjects.JobSchedulingInfo{ PriorityClassName: podSpec.PriorityClassName, - Priority: LogSubmitPriorityFromApiPriority(job.GetPriority()), - SubmitTime: job.GetCreated(), + Priority: job.GetPerQueuePriority(), + SubmitTime: job.GetSubmitTime(), ObjectRequirements: []*schedulerobjects.ObjectRequirements{ { Requirements: &schedulerobjects.ObjectRequirements_PodRequirements{ @@ -249,19 +260,6 @@ func (job *Job) GetJobSet() string { return job.JobSetId } -// LogSubmitPriorityFromApiPriority returns the uint32 representation of the priority included with a submitted job, -// or an error if the conversion fails. -func LogSubmitPriorityFromApiPriority(priority float64) uint32 { - if priority < 0 { - priority = 0 - } - if priority > math.MaxUint32 { - priority = math.MaxUint32 - } - priority = math.Round(priority) - return uint32(priority) -} - func (job *Job) GetMainPodSpec() *v1.PodSpec { if job.PodSpec != nil { return job.PodSpec diff --git a/pkg/api/util_test.go b/pkg/api/util_test.go index 13147463501..2e9315f031c 100644 --- a/pkg/api/util_test.go +++ b/pkg/api/util_test.go @@ -503,7 +503,7 @@ func TestJobGetRequirements(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - assert.Equal(t, tc.expected, tc.job.GetRequirements(TestPriorityClasses)) + assert.Equal(t, tc.expected, tc.job.GetJobSchedulingInfo(TestPriorityClasses)) }) } }
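One last note on the pkg/api changes: GetPerQueuePriority replaces LogSubmitPriorityFromApiPriority with the same clamp-and-round semantics, now exposed as a method on Job and used by GetJobSchedulingInfo to populate JobSchedulingInfo.Priority. The standalone sketch below (not part of the diff) demonstrates the boundary behaviour; perQueuePriority is a hypothetical free function mirroring the method.

package main

import (
	"fmt"
	"math"
)

// perQueuePriority mirrors Job.GetPerQueuePriority: the float64 API priority is
// clamped to [0, math.MaxUint32] and rounded to the nearest integer.
func perQueuePriority(priority float64) uint32 {
	if priority < 0 {
		return 0
	}
	if priority > math.MaxUint32 {
		return math.MaxUint32
	}
	return uint32(math.Round(priority))
}

func main() {
	for _, p := range []float64{-1, 0.4, 3, 1e12} {
		fmt.Printf("api priority %v -> per-queue priority %d\n", p, perQueuePriority(p))
	}
}

Negative priorities collapse to 0, values above math.MaxUint32 saturate, and everything else is rounded to the nearest integer.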