diff --git a/internal/common/database/lookout/jobstates.go b/internal/common/database/lookout/jobstates.go index 9ba1ce54f31..20ea463dde5 100644 --- a/internal/common/database/lookout/jobstates.go +++ b/internal/common/database/lookout/jobstates.go @@ -51,6 +51,18 @@ const ( ) var ( + // JobStates is an ordered list of states + JobStates = []JobState{ + JobQueued, + JobLeased, + JobPending, + JobRunning, + JobSucceeded, + JobFailed, + JobCancelled, + JobPreempted, + } + JobStateMap = map[int]JobState{ JobLeasedOrdinal: JobLeased, JobQueuedOrdinal: JobQueued, diff --git a/internal/lookoutv2/conversions/convert_test.go b/internal/lookoutv2/conversions/convert_test.go index 9d5649156ac..32130e63ff3 100644 --- a/internal/lookoutv2/conversions/convert_test.go +++ b/internal/lookoutv2/conversions/convert_test.go @@ -86,16 +86,22 @@ var ( } swaggerGroup = &models.Group{ - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "averageTimeInState": "3d", + "state": map[string]int{ + "QUEUED": 321, + }, }, Count: 1000, Name: "queue-1", } group = &model.JobGroup{ - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "averageTimeInState": "3d", + "state": map[string]int{ + "QUEUED": 321, + }, }, Count: 1000, Name: "queue-1", diff --git a/internal/lookoutv2/gen/models/group.go b/internal/lookoutv2/gen/models/group.go index 71adda73be1..25c8d68892a 100644 --- a/internal/lookoutv2/gen/models/group.go +++ b/internal/lookoutv2/gen/models/group.go @@ -21,7 +21,7 @@ type Group struct { // aggregates // Required: true - Aggregates map[string]string `json:"aggregates"` + Aggregates map[string]interface{} `json:"aggregates"` // count // Required: true @@ -61,6 +61,14 @@ func (m *Group) validateAggregates(formats strfmt.Registry) error { return err } + for k := range m.Aggregates { + + if err := validate.Required("aggregates"+"."+k, "body", m.Aggregates[k]); err != nil { + return err + } + + } + return nil } diff --git 
a/internal/lookoutv2/gen/restapi/embedded_spec.go b/internal/lookoutv2/gen/restapi/embedded_spec.go index 629d7b45500..c57b6290da2 100644 --- a/internal/lookoutv2/gen/restapi/embedded_spec.go +++ b/internal/lookoutv2/gen/restapi/embedded_spec.go @@ -426,7 +426,7 @@ func init() { "aggregates": { "type": "object", "additionalProperties": { - "type": "string" + "type": "object" }, "x-nullable": false }, @@ -1082,7 +1082,7 @@ func init() { "aggregates": { "type": "object", "additionalProperties": { - "type": "string" + "type": "object" }, "x-nullable": false }, diff --git a/internal/lookoutv2/model/model.go b/internal/lookoutv2/model/model.go index 349541d54cb..0d22f87ec3c 100644 --- a/internal/lookoutv2/model/model.go +++ b/internal/lookoutv2/model/model.go @@ -53,7 +53,7 @@ type Run struct { } type JobGroup struct { - Aggregates map[string]string + Aggregates map[string]interface{} Count int64 Name string } diff --git a/internal/lookoutv2/repository/aggregates.go b/internal/lookoutv2/repository/aggregates.go new file mode 100644 index 00000000000..ad7c1386dba --- /dev/null +++ b/internal/lookoutv2/repository/aggregates.go @@ -0,0 +1,133 @@ +package repository + +import ( + "fmt" + + "github.com/pkg/errors" + + "github.com/armadaproject/armada/internal/common/database/lookout" + "github.com/armadaproject/armada/internal/common/util" + "github.com/armadaproject/armada/internal/lookoutv2/model" +) + +type QueryAggregator interface { + AggregateSql() (string, error) +} + +type SqlFunctionAggregator struct { + queryCol *queryColumn + sqlFunction string +} + +func NewSqlFunctionAggregator(queryCol *queryColumn, fn string) *SqlFunctionAggregator { + return &SqlFunctionAggregator{ + queryCol: queryCol, + sqlFunction: fn, + } +} + +func (qa *SqlFunctionAggregator) aggregateColName() string { + return qa.queryCol.name +} + +func (qa *SqlFunctionAggregator) AggregateSql() (string, error) { + return fmt.Sprintf("%s(%s.%s) AS %s", qa.sqlFunction, qa.queryCol.abbrev, 
qa.queryCol.name, qa.aggregateColName()), nil +} + +type StateCountAggregator struct { + queryCol *queryColumn + stateString string +} + +func NewStateCountAggregator(queryCol *queryColumn, stateString string) *StateCountAggregator { + return &StateCountAggregator{ + queryCol: queryCol, + stateString: stateString, + } +} + +func (qa *StateCountAggregator) aggregateColName() string { + return fmt.Sprintf("%s_%s", qa.queryCol.name, qa.stateString) +} + +func (qa *StateCountAggregator) AggregateSql() (string, error) { + stateInt, ok := lookout.JobStateOrdinalMap[lookout.JobState(qa.stateString)] + if !ok { + return "", errors.Errorf("state %s does not exist", qa.stateString) + } + return fmt.Sprintf( + "SUM(CASE WHEN %s.%s = %d THEN 1 ELSE 0 END) AS %s", + qa.queryCol.abbrev, qa.queryCol.name, stateInt, qa.aggregateColName(), + ), nil +} + +func GetAggregatorsForColumn(queryCol *queryColumn, aggregateType AggregateType, filters []*model.Filter) ([]QueryAggregator, error) { + switch aggregateType { + case Max: + return []QueryAggregator{NewSqlFunctionAggregator(queryCol, "MAX")}, nil + case Average: + return []QueryAggregator{NewSqlFunctionAggregator(queryCol, "AVG")}, nil + case StateCounts: + states := GetStatesForFilter(filters) + aggregators := make([]QueryAggregator, len(states)) + for i, state := range states { + aggregators[i] = NewStateCountAggregator(queryCol, state) + } + return aggregators, nil + default: + return nil, errors.Errorf("cannot determine aggregate type: %v", aggregateType) + } +} + +// GetStatesForFilter returns a list of states as string if filter for state exists +// Will always return the states in the same order, irrespective of the ordering of the states in the filter +func GetStatesForFilter(filters []*model.Filter) []string { + var stateFilter *model.Filter + for _, f := range filters { + if f.Field == stateField { + stateFilter = f + } + } + allStates := util.Map(lookout.JobStates, func(jobState lookout.JobState) string { return 
string(jobState) }) + if stateFilter == nil { + // If no state filter is specified, use all states + return allStates + } + + switch stateFilter.Match { + case model.MatchExact: + return []string{fmt.Sprintf("%s", stateFilter.Value)} + case model.MatchAnyOf: + strSlice, err := toStringSlice(stateFilter.Value) + if err != nil { + return allStates + } + stateStringSet := util.StringListToSet(strSlice) + // Ensuring they are in the same order + var finalStates []string + for _, state := range allStates { + if _, ok := stateStringSet[state]; ok { + finalStates = append(finalStates, state) + } + } + return finalStates + default: + return allStates + } +} + +func toStringSlice(val interface{}) ([]string, error) { + switch v := val.(type) { + case []string: + return v, nil + case []interface{}: + result := make([]string, len(v)) + for i := 0; i < len(v); i++ { + str := fmt.Sprintf("%v", v[i]) + result[i] = str + } + return result, nil + default: + return nil, errors.Errorf("failed to convert interface to string slice: %v of type %T", val, val) + } +} diff --git a/internal/lookoutv2/repository/fieldparser.go b/internal/lookoutv2/repository/fieldparser.go new file mode 100644 index 00000000000..e8ddde0996b --- /dev/null +++ b/internal/lookoutv2/repository/fieldparser.go @@ -0,0 +1,122 @@ +package repository + +import ( + "fmt" + "math" + "time" + + "github.com/jackc/pgtype" + "github.com/pkg/errors" + + "github.com/armadaproject/armada/internal/common/database/lookout" + "github.com/armadaproject/armada/internal/lookoutv2/model" +) + +type FieldParser interface { + GetField() string + GetVariableRef() interface{} + ParseValue() (interface{}, error) +} + +type LastTransitionTimeParser struct { + variable pgtype.Numeric +} + +func (fp *LastTransitionTimeParser) GetField() string { + return lastTransitionTimeField +} + +func (fp *LastTransitionTimeParser) GetVariableRef() interface{} { + return &fp.variable +} + +func (fp *LastTransitionTimeParser) ParseValue() (interface{}, 
error) { + var dst float64 + err := fp.variable.AssignTo(&dst) + if err != nil { + return "", err + } + t := time.Unix(int64(math.Round(dst)), 0) + return t.Format(time.RFC3339), nil +} + +type TimeParser struct { + field string + variable time.Time +} + +func (fp *TimeParser) GetField() string { + return fp.field +} + +func (fp *TimeParser) GetVariableRef() interface{} { + return &fp.variable +} + +func (fp *TimeParser) ParseValue() (interface{}, error) { + return fp.variable.Format(time.RFC3339), nil +} + +type StateParser struct { + variable int16 +} + +func (fp *StateParser) GetField() string { + return stateField +} + +func (fp *StateParser) GetVariableRef() interface{} { + return &fp.variable +} + +func (fp *StateParser) ParseValue() (interface{}, error) { + state, ok := lookout.JobStateMap[int(fp.variable)] + if !ok { + return "", errors.Errorf("state not found: %d", fp.variable) + } + return string(state), nil +} + +type BasicParser[T any] struct { + field string + variable T +} + +func (fp *BasicParser[T]) GetField() string { + return fp.field +} + +func (fp *BasicParser[T]) GetVariableRef() interface{} { + return &fp.variable +} + +func (fp *BasicParser[T]) ParseValue() (interface{}, error) { + return fp.variable, nil +} + +func ParserForGroup(field string) FieldParser { + switch field { + case stateField: + return &StateParser{} + default: + return &BasicParser[string]{field: field} + } +} + +func ParsersForAggregate(field string, filters []*model.Filter) ([]FieldParser, error) { + var parsers []FieldParser + switch field { + case lastTransitionTimeField: + parsers = append(parsers, &LastTransitionTimeParser{}) + case submittedField: + parsers = append(parsers, &TimeParser{field: submittedField}) + case stateField: + states := GetStatesForFilter(filters) + for _, state := range states { + parsers = append(parsers, &BasicParser[int]{field: fmt.Sprintf("%s%s", stateAggregatePrefix, state)}) + } + default: + return nil, errors.Errorf("no aggregate found for 
field %s", field) + } + return parsers, nil +} diff --git a/internal/lookoutv2/repository/groupjobs.go b/internal/lookoutv2/repository/groupjobs.go index 1988e4a31ce..f8fe0b37206 100644 --- a/internal/lookoutv2/repository/groupjobs.go +++ b/internal/lookoutv2/repository/groupjobs.go @@ -2,16 +2,14 @@ package repository import ( "context" - "math" - "time" + "fmt" + "strings" - "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/pgxpool" "github.com/pkg/errors" "github.com/armadaproject/armada/internal/common/database" - "github.com/armadaproject/armada/internal/common/database/lookout" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/lookoutv2/model" ) @@ -39,15 +37,7 @@ type SqlGroupJobsRepository struct { lookoutTables *LookoutTables } -type scanVarInit func() interface{} - -type parserFn func(interface{}) (string, error) - -type scanContext struct { - field string - varInit scanVarInit - parser parserFn -} +const stateAggregatePrefix = "state_" func NewSqlGroupJobsRepository(db *pgxpool.Pool) *SqlGroupJobsRepository { return &SqlGroupJobsRepository{ @@ -95,7 +85,7 @@ func (r *SqlGroupJobsRepository) GroupBy( if err != nil { return err } - groups, err = rowsToGroups(groupRows, groupedField, aggregates) + groups, err = rowsToGroups(groupRows, groupedField, aggregates, filters) return err }) if err != nil { @@ -108,10 +98,10 @@ func (r *SqlGroupJobsRepository) GroupBy( }, nil } -func rowsToGroups(rows pgx.Rows, groupedField *model.GroupedField, aggregates []string) ([]*model.JobGroup, error) { +func rowsToGroups(rows pgx.Rows, groupedField *model.GroupedField, aggregates []string, filters []*model.Filter) ([]*model.JobGroup, error) { var groups []*model.JobGroup for rows.Next() { - jobGroup, err := scanGroup(rows, groupedField.Field, aggregates) + jobGroup, err := scanGroup(rows, groupedField.Field, aggregates, filters) if err != nil { return nil, err } @@ -120,143 +110,59 @@ func 
rowsToGroups(rows pgx.Rows, groupedField *model.GroupedField, aggregates [] return groups, nil } -func scanGroup(rows pgx.Rows, field string, aggregates []string) (*model.JobGroup, error) { - groupScanContext, err := groupScanContextForField(field) - if err != nil { - return nil, err - } - group := groupScanContext.varInit() +func scanGroup(rows pgx.Rows, field string, aggregates []string, filters []*model.Filter) (*model.JobGroup, error) { + groupParser := ParserForGroup(field) var count int64 - - scanContexts := make([]*scanContext, len(aggregates)) - aggregateVars := make([]interface{}, len(aggregates)) - for i, aggregate := range aggregates { - sc, err := aggregateScanContextForField(aggregate) + var aggregateParsers []FieldParser + for _, aggregate := range aggregates { + parsers, err := ParsersForAggregate(aggregate, filters) if err != nil { return nil, err } - aggregateVars[i] = sc.varInit() - scanContexts[i] = sc + aggregateParsers = append(aggregateParsers, parsers...) } - aggregateRefs := make([]interface{}, len(aggregates)) - for i := 0; i < len(aggregates); i++ { - aggregateRefs[i] = &aggregateVars[i] + aggregateRefs := make([]interface{}, len(aggregateParsers)) + for i, parser := range aggregateParsers { + aggregateRefs[i] = parser.GetVariableRef() } - varAddresses := util.Concat([]interface{}{&group, &count}, aggregateRefs) - err = rows.Scan(varAddresses...) + varAddresses := util.Concat([]interface{}{groupParser.GetVariableRef(), &count}, aggregateRefs) + err := rows.Scan(varAddresses...) 
if err != nil { return nil, err } - parsedGroup, err := groupScanContext.parser(group) + parsedGroup, err := groupParser.ParseValue() if err != nil { return nil, err } - aggregatesMap := make(map[string]string) - for i, sc := range scanContexts { - val := aggregateVars[i] - parsedVal, err := sc.parser(val) + aggregatesMap := make(map[string]interface{}) + for _, parser := range aggregateParsers { + val, err := parser.ParseValue() if err != nil { - return nil, errors.Wrapf(err, "failed to parse value for field %s", sc.field) + return nil, errors.Wrapf(err, "failed to parse value for field %s", parser.GetField()) + } + if strings.HasPrefix(parser.GetField(), stateAggregatePrefix) { + singleStateCount, ok := val.(int) + if !ok { + return nil, errors.Errorf("failed to parse value for state aggregate: cannot convert value to int: %v: %T", singleStateCount, singleStateCount) + } + stateCountsVal, ok := aggregatesMap[stateField] + if !ok { + stateCountsVal = map[string]int{} + aggregatesMap[stateField] = stateCountsVal + } + stateCounts, ok := stateCountsVal.(map[string]int) + if !ok { + return nil, errors.Errorf("failed to parse value for state aggregate: cannot cast state counts to map") + } + state := parser.GetField()[len(stateAggregatePrefix):] + stateCounts[state] = singleStateCount + } else { + aggregatesMap[parser.GetField()] = val } - aggregatesMap[sc.field] = parsedVal } return &model.JobGroup{ - Name: parsedGroup, + Name: fmt.Sprintf("%s", parsedGroup), Count: count, Aggregates: aggregatesMap, }, nil } - -func groupScanContextForField(field string) (*scanContext, error) { - switch field { - case stateField: - return &scanContext{ - field: field, - varInit: int16ScanVar, - parser: stateParser, - }, nil - default: - return &scanContext{ - field: field, - varInit: stringScanVar, - parser: stringParser, - }, nil - } -} - -func aggregateScanContextForField(field string) (*scanContext, error) { - switch field { - case lastTransitionTimeField: - return &scanContext{ - 
field: lastTransitionTimeField, - varInit: numericScanVar, - parser: avgLastTransitionTimeParser, - }, nil - case submittedField: - return &scanContext{ - field: submittedField, - varInit: timeScanVar, - parser: maxSubmittedTimeParser, - }, nil - default: - return nil, errors.Errorf("no aggregate found for field %s", field) - } -} - -func stringScanVar() interface{} { - return "" -} - -func int16ScanVar() interface{} { - return int16(0) -} - -func numericScanVar() interface{} { - return pgtype.Numeric{} -} - -func timeScanVar() interface{} { - return time.Time{} -} - -func avgLastTransitionTimeParser(val interface{}) (string, error) { - lastTransitionTimeSeconds, ok := val.(pgtype.Numeric) - if !ok { - return "", errors.Errorf("could not convert %v: %T to int64", val, val) - } - var dst float64 - err := lastTransitionTimeSeconds.AssignTo(&dst) - if err != nil { - return "", err - } - t := time.Unix(int64(math.Round(dst)), 0) - return t.Format(time.RFC3339), nil -} - -func maxSubmittedTimeParser(val interface{}) (string, error) { - maxSubmittedTime, ok := val.(time.Time) - if !ok { - return "", errors.Errorf("could not convert %v: %T to time", val, val) - } - return maxSubmittedTime.Format(time.RFC3339), nil -} - -func stateParser(val interface{}) (string, error) { - stateInt, ok := val.(int16) - if !ok { - return "", errors.Errorf("could not convert %v: %T to int for state", val, val) - } - state, ok := lookout.JobStateMap[int(stateInt)] - if !ok { - return "", errors.Errorf("state not found: %d", stateInt) - } - return string(state), nil -} - -func stringParser(val interface{}) (string, error) { - str, ok := val.(string) - if !ok { - return "", errors.Errorf("could not convert %v: %T to string", val, val) - } - return str, nil -} diff --git a/internal/lookoutv2/repository/groupjobs_test.go b/internal/lookoutv2/repository/groupjobs_test.go index 29fb24a507c..2ca98fd8a26 100644 --- a/internal/lookoutv2/repository/groupjobs_test.go +++ 
b/internal/lookoutv2/repository/groupjobs_test.go @@ -59,17 +59,17 @@ func TestGroupByQueue(t *testing.T) { { Name: "queue-1", Count: 10, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, { Name: "queue-2", Count: 5, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, { Name: "queue-3", Count: 3, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, }) return nil @@ -117,17 +117,17 @@ func TestGroupByJobSet(t *testing.T) { { Name: "job-set-1", Count: 10, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, { Name: "job-set-2", Count: 5, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, { Name: "job-set-3", Count: 3, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, }) return nil @@ -183,22 +183,22 @@ func TestGroupByState(t *testing.T) { { Name: string(lookout.JobQueued), Count: 10, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, { Name: string(lookout.JobPending), Count: 5, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, { Name: string(lookout.JobRunning), Count: 3, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, { Name: string(lookout.JobFailed), Count: 2, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, }) return nil @@ -370,22 +370,22 @@ func TestGroupByWithFilters(t *testing.T) { { Name: string(lookout.JobQueued), Count: 10, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, { Name: string(lookout.JobPending), Count: 5, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, { Name: string(lookout.JobRunning), Count: 3, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, { Name: string(lookout.JobFailed), Count: 2, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, }) return nil @@ 
-468,21 +468,21 @@ func TestGroupJobsWithMaxSubmittedTime(t *testing.T) { { Name: "job-set-1", Count: 15, - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "submitted": baseTime.Format(time.RFC3339), }, }, { Name: "job-set-2", Count: 12, - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "submitted": baseTime.Add(-4 * time.Minute).Format(time.RFC3339), }, }, { Name: "job-set-3", Count: 18, - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "submitted": baseTime.Add(-7 * time.Minute).Format(time.RFC3339), }, }, @@ -567,21 +567,21 @@ func TestGroupJobsWithAvgLastTransitionTime(t *testing.T) { { Name: "queue-3", Count: 18, - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "lastTransitionTime": baseTime.Add(-8 * time.Minute).Format(time.RFC3339), }, }, { Name: "queue-2", Count: 12, - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "lastTransitionTime": baseTime.Add(-5 * time.Minute).Format(time.RFC3339), }, }, { Name: "queue-1", Count: 15, - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "lastTransitionTime": baseTime.Add(-1 * time.Minute).Format(time.RFC3339), }, }, @@ -591,6 +591,237 @@ func TestGroupJobsWithAvgLastTransitionTime(t *testing.T) { assert.NoError(t, err) } +func TestGroupJobsWithAllStateCounts(t *testing.T) { + err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { + converter := instructions.NewInstructionConverter(metrics.Get(), userAnnotationPrefix, &compress.NoOpCompressor{}, false) + store := lookoutdb.NewLookoutDb(db, metrics.Get(), 3, 10) + + manyJobs(5, &createJobsOpts{ + queue: "queue-1", + jobSet: "job-set-1", + state: lookout.JobQueued, + }, converter, store) + manyJobs(6, &createJobsOpts{ + queue: "queue-1", + jobSet: "job-set-1", + state: lookout.JobPending, + }, converter, store) + manyJobs(7, &createJobsOpts{ + queue: "queue-1", + jobSet: "job-set-1", + state: lookout.JobRunning, + }, converter, store) + + 
manyJobs(8, &createJobsOpts{ + queue: "queue-2", + jobSet: "job-set-2", + state: lookout.JobLeased, + }, converter, store) + manyJobs(9, &createJobsOpts{ + queue: "queue-2", + jobSet: "job-set-2", + state: lookout.JobPreempted, + }, converter, store) + manyJobs(10, &createJobsOpts{ + queue: "queue-2", + jobSet: "job-set-2", + state: lookout.JobCancelled, + }, converter, store) + + manyJobs(11, &createJobsOpts{ + queue: "queue-3", + jobSet: "job-set-3", + state: lookout.JobSucceeded, + }, converter, store) + manyJobs(12, &createJobsOpts{ + queue: "queue-3", + jobSet: "job-set-3", + state: lookout.JobFailed, + }, converter, store) + manyJobs(13, &createJobsOpts{ + queue: "queue-3", + jobSet: "job-set-3", + state: lookout.JobQueued, + }, converter, store) + + repo := NewSqlGroupJobsRepository(db) + result, err := repo.GroupBy( + context.TODO(), + []*model.Filter{}, + &model.Order{ + Field: "count", + Direction: "ASC", + }, + &model.GroupedField{ + Field: "jobSet", + }, + []string{"state"}, + 0, + 10, + ) + assert.NoError(t, err) + assert.Len(t, result.Groups, 3) + assert.Equal(t, 3, result.Count) + assert.Equal(t, []*model.JobGroup{ + { + Name: "job-set-1", + Count: 18, + Aggregates: map[string]interface{}{ + "state": map[string]int{ + string(lookout.JobQueued): 5, + string(lookout.JobLeased): 0, + string(lookout.JobPending): 6, + string(lookout.JobRunning): 7, + string(lookout.JobSucceeded): 0, + string(lookout.JobFailed): 0, + string(lookout.JobCancelled): 0, + string(lookout.JobPreempted): 0, + }, + }, + }, + { + Name: "job-set-2", + Count: 27, + Aggregates: map[string]interface{}{ + "state": map[string]int{ + string(lookout.JobQueued): 0, + string(lookout.JobLeased): 8, + string(lookout.JobPending): 0, + string(lookout.JobRunning): 0, + string(lookout.JobSucceeded): 0, + string(lookout.JobFailed): 0, + string(lookout.JobCancelled): 10, + string(lookout.JobPreempted): 9, + }, + }, + }, + { + Name: "job-set-3", + Count: 36, + Aggregates: map[string]interface{}{ + 
"state": map[string]int{ + string(lookout.JobQueued): 13, + string(lookout.JobLeased): 0, + string(lookout.JobPending): 0, + string(lookout.JobRunning): 0, + string(lookout.JobSucceeded): 11, + string(lookout.JobFailed): 12, + string(lookout.JobCancelled): 0, + string(lookout.JobPreempted): 0, + }, + }, + }, + }, result.Groups) + return nil + }) + assert.NoError(t, err) +} + +func TestGroupJobsWithFilteredStateCounts(t *testing.T) { + err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { + converter := instructions.NewInstructionConverter(metrics.Get(), userAnnotationPrefix, &compress.NoOpCompressor{}, false) + store := lookoutdb.NewLookoutDb(db, metrics.Get(), 3, 10) + + manyJobs(5, &createJobsOpts{ + queue: "queue-1", + jobSet: "job-set-1", + state: lookout.JobQueued, + }, converter, store) + manyJobs(6, &createJobsOpts{ + queue: "queue-1", + jobSet: "job-set-1", + state: lookout.JobPending, + }, converter, store) + manyJobs(7, &createJobsOpts{ + queue: "queue-1", + jobSet: "job-set-1", + state: lookout.JobRunning, + }, converter, store) + + manyJobs(9, &createJobsOpts{ + queue: "queue-2", + jobSet: "job-set-2", + state: lookout.JobPreempted, + }, converter, store) + manyJobs(10, &createJobsOpts{ + queue: "queue-2", + jobSet: "job-set-2", + state: lookout.JobCancelled, + }, converter, store) + + manyJobs(11, &createJobsOpts{ + queue: "queue-3", + jobSet: "job-set-3", + state: lookout.JobSucceeded, + }, converter, store) + manyJobs(12, &createJobsOpts{ + queue: "queue-3", + jobSet: "job-set-3", + state: lookout.JobFailed, + }, converter, store) + manyJobs(13, &createJobsOpts{ + queue: "queue-3", + jobSet: "job-set-3", + state: lookout.JobQueued, + }, converter, store) + + repo := NewSqlGroupJobsRepository(db) + result, err := repo.GroupBy( + context.TODO(), + []*model.Filter{ + { + Field: stateField, + Match: model.MatchAnyOf, + Value: []string{ + string(lookout.JobQueued), + string(lookout.JobPending), + string(lookout.JobRunning), + }, + }, + }, + 
&model.Order{ + Field: "count", + Direction: "DESC", + }, + &model.GroupedField{ + Field: "jobSet", + }, + []string{"state"}, + 0, + 10, + ) + assert.NoError(t, err) + assert.Len(t, result.Groups, 2) + assert.Equal(t, 2, result.Count) + assert.Equal(t, []*model.JobGroup{ + { + Name: "job-set-1", + Count: 18, + Aggregates: map[string]interface{}{ + "state": map[string]int{ + string(lookout.JobQueued): 5, + string(lookout.JobPending): 6, + string(lookout.JobRunning): 7, + }, + }, + }, + { + Name: "job-set-3", + Count: 13, + Aggregates: map[string]interface{}{ + "state": map[string]int{ + string(lookout.JobQueued): 13, + string(lookout.JobPending): 0, + string(lookout.JobRunning): 0, + }, + }, + }, + }, result.Groups) + return nil + }) + assert.NoError(t, err) +} + func TestGroupJobsComplex(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { converter := instructions.NewInstructionConverter(metrics.Get(), userAnnotationPrefix, &compress.NoOpCompressor{}, true) @@ -709,7 +940,7 @@ func TestGroupJobsComplex(t *testing.T) { { Name: "job-set-2", Count: 2, - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "submitted": baseTime.Add(20 * time.Minute).Format(time.RFC3339), "lastTransitionTime": baseTime.Add(50 * time.Minute).Format(time.RFC3339), }, @@ -717,7 +948,7 @@ func TestGroupJobsComplex(t *testing.T) { { Name: "job-set-1", Count: 15, - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "submitted": baseTime.Add(3 * time.Minute).Format(time.RFC3339), "lastTransitionTime": baseTime.Add(5 * time.Minute).Format(time.RFC3339), }, @@ -778,17 +1009,17 @@ func TestGroupByAnnotation(t *testing.T) { { Name: "test-value-1", Count: 10, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, { Name: "test-value-2", Count: 5, - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, }, { Name: "test-value-3", Count: 3, - Aggregates: map[string]string{}, + Aggregates: 
map[string]interface{}{}, }, }) return nil @@ -907,7 +1138,7 @@ func TestGroupByAnnotationWithFiltersAndAggregates(t *testing.T) { { Name: "4", Count: 2, - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "submitted": baseTime.Add(20 * time.Minute).Format(time.RFC3339), "lastTransitionTime": baseTime.Add(50 * time.Minute).Format(time.RFC3339), }, @@ -915,7 +1146,7 @@ func TestGroupByAnnotationWithFiltersAndAggregates(t *testing.T) { { Name: "2", Count: 5, - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "submitted": baseTime.Add(1 * time.Minute).Format(time.RFC3339), "lastTransitionTime": baseTime.Add(10 * time.Minute).Format(time.RFC3339), }, @@ -923,7 +1154,7 @@ func TestGroupByAnnotationWithFiltersAndAggregates(t *testing.T) { { Name: "3", Count: 5, - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "submitted": baseTime.Add(3 * time.Minute).Format(time.RFC3339), "lastTransitionTime": baseTime.Add(5 * time.Minute).Format(time.RFC3339), }, @@ -931,7 +1162,7 @@ func TestGroupByAnnotationWithFiltersAndAggregates(t *testing.T) { { Name: "1", Count: 5, - Aggregates: map[string]string{ + Aggregates: map[string]interface{}{ "submitted": baseTime.Format(time.RFC3339), "lastTransitionTime": baseTime.Format(time.RFC3339), }, @@ -960,7 +1191,7 @@ func TestGroupJobsSkip(t *testing.T) { return &model.JobGroup{ Name: fmt.Sprintf("queue-%d", i), Count: int64(i), - Aggregates: map[string]string{}, + Aggregates: map[string]interface{}{}, } } @@ -1160,12 +1391,20 @@ func getCreateJobsFn(state lookout.JobState) createJobsFn { switch state { case lookout.JobQueued: return makeQueued + case lookout.JobLeased: + return makeLeased case lookout.JobPending: return makePending case lookout.JobRunning: return makeRunning + case lookout.JobSucceeded: + return makeSucceeded case lookout.JobFailed: return makeFailed + case lookout.JobCancelled: + return makeCancelled + case lookout.JobPreempted: + return makePreempted default: 
return makeQueued } @@ -1186,6 +1425,23 @@ func makeQueued(opts *createJobsOpts, converter *instructions.InstructionConvert Build() } +func makeLeased(opts *createJobsOpts, converter *instructions.InstructionConverter, store *lookoutdb.LookoutDb) { + tSubmit := baseTime + if opts.submittedTime != nil { + tSubmit = *opts.submittedTime + } + lastTransitionTime := baseTime + if opts.lastTransitionTime != nil { + lastTransitionTime = *opts.lastTransitionTime + } + NewJobSimulator(converter, store). + Submit(opts.queue, opts.jobSet, owner, tSubmit, &JobOptions{ + Annotations: opts.annotations, + }). + Lease(uuid.NewString(), lastTransitionTime). + Build() +} + func makePending(opts *createJobsOpts, converter *instructions.InstructionConverter, store *lookoutdb.LookoutDb) { tSubmit := baseTime if opts.submittedTime != nil { @@ -1222,6 +1478,27 @@ func makeRunning(opts *createJobsOpts, converter *instructions.InstructionConver Build() } +func makeSucceeded(opts *createJobsOpts, converter *instructions.InstructionConverter, store *lookoutdb.LookoutDb) { + tSubmit := baseTime + if opts.submittedTime != nil { + tSubmit = *opts.submittedTime + } + lastTransitionTime := baseTime + if opts.lastTransitionTime != nil { + lastTransitionTime = *opts.lastTransitionTime + } + runId := uuid.NewString() + NewJobSimulator(converter, store). + Submit(opts.queue, opts.jobSet, owner, tSubmit, &JobOptions{ + Annotations: opts.annotations, + }). + Pending(runId, cluster, lastTransitionTime.Add(-2*time.Minute)). + Running(runId, cluster, lastTransitionTime.Add(-1*time.Minute)). + RunSucceeded(runId, lastTransitionTime). + Succeeded(lastTransitionTime). + Build() +} + func makeFailed(opts *createJobsOpts, converter *instructions.InstructionConverter, store *lookoutdb.LookoutDb) { tSubmit := baseTime if opts.submittedTime != nil { @@ -1242,3 +1519,40 @@ func makeFailed(opts *createJobsOpts, converter *instructions.InstructionConvert Failed(node, 1, "error", lastTransitionTime). 
Build() } + +func makeCancelled(opts *createJobsOpts, converter *instructions.InstructionConverter, store *lookoutdb.LookoutDb) { + tSubmit := baseTime + if opts.submittedTime != nil { + tSubmit = *opts.submittedTime + } + lastTransitionTime := baseTime + if opts.lastTransitionTime != nil { + lastTransitionTime = *opts.lastTransitionTime + } + NewJobSimulator(converter, store). + Submit(opts.queue, opts.jobSet, owner, tSubmit, &JobOptions{ + Annotations: opts.annotations, + }). + Cancelled(lastTransitionTime). + Build() +} + +func makePreempted(opts *createJobsOpts, converter *instructions.InstructionConverter, store *lookoutdb.LookoutDb) { + tSubmit := baseTime + if opts.submittedTime != nil { + tSubmit = *opts.submittedTime + } + lastTransitionTime := baseTime + if opts.lastTransitionTime != nil { + lastTransitionTime = *opts.lastTransitionTime + } + runId := uuid.NewString() + NewJobSimulator(converter, store). + Submit(opts.queue, opts.jobSet, owner, tSubmit, &JobOptions{ + Annotations: opts.annotations, + }). + Pending(runId, cluster, lastTransitionTime.Add(-2*time.Minute)). + Running(runId, cluster, lastTransitionTime.Add(-1*time.Minute)). + Preempted(lastTransitionTime). + Build() +} diff --git a/internal/lookoutv2/repository/common.go b/internal/lookoutv2/repository/querybuilder.go similarity index 95% rename from internal/lookoutv2/repository/common.go rename to internal/lookoutv2/repository/querybuilder.go index 33e1725db02..c0999dbd5dd 100644 --- a/internal/lookoutv2/repository/common.go +++ b/internal/lookoutv2/repository/querybuilder.go @@ -58,14 +58,6 @@ type queryOrder struct { direction string } -// Get aggregation expression for column, e.g. 
MAX(j.submitted)
-type aggregatorFn func(column *queryColumn) string
-
-type queryAggregator struct {
-	column     *queryColumn
-	aggregator aggregatorFn
-}
-
 func NewQueryBuilder(lookoutTables *LookoutTables) *QueryBuilder {
 	return &QueryBuilder{
 		lookoutTables: lookoutTables,
@@ -368,11 +360,14 @@ func (qb *QueryBuilder) GroupBy(
 	if err != nil {
 		return nil, err
 	}
-	queryAggregators, err := qb.getQueryAggregators(aggregates, queryTables)
+	queryAggregators, err := qb.getQueryAggregators(aggregates, normalFilters, queryTables) // the filters are forwarded to GetAggregatorsForColumn
+	if err != nil {
+		return nil, err
+	}
+	selectListSql, err := qb.getAggregatesSql(queryAggregators) // rendering an aggregator can fail, so the error is propagated
 	if err != nil {
 		return nil, err
 	}
-	selectListSql := qb.getAggregatesSql(queryAggregators)
 	orderSql, err := qb.groupByOrderSql(order)
 	if err != nil {
 		return nil, err
@@ -912,9 +907,9 @@ func (qb *QueryBuilder) highestPrecedenceTableForColumn(col string, queryTables
 	return selectedTable, nil
 }
 
-func (qb *QueryBuilder) getQueryAggregators(aggregates []string, queryTables map[string]bool) ([]*queryAggregator, error) {
-	queryAggregators := make([]*queryAggregator, len(aggregates))
-	for i, aggregate := range aggregates {
+func (qb *QueryBuilder) getQueryAggregators(aggregates []string, filters []*model.Filter, queryTables map[string]bool) ([]QueryAggregator, error) {
+	var queryAggregators []QueryAggregator // grown with append: one column may contribute several aggregators
+	for _, aggregate := range aggregates {
 		col, err := qb.lookoutTables.ColumnFromField(aggregate)
 		if err != nil {
 			return nil, err
@@ -927,25 +922,25 @@ func (qb *QueryBuilder) getQueryAggregators(aggregates []string, queryTables map
 		if err != nil {
 			return nil, err
 		}
-		fn, err := getAggregatorFn(aggregateType)
+		newQueryAggregators, err := GetAggregatorsForColumn(qc, aggregateType, filters)
 		if err != nil {
 			return nil, err
 		}
-		queryAggregators[i] = &queryAggregator{
-			column:     qc,
-			aggregator: fn,
-		}
+		queryAggregators = append(queryAggregators, newQueryAggregators...)
 	}
 	return queryAggregators, nil
 }
 
-func (qb *QueryBuilder) getAggregatesSql(aggregators []*queryAggregator) string {
+func (qb *QueryBuilder) getAggregatesSql(aggregators []QueryAggregator) (string, error) {
 	selectList := []string{"COUNT(*) AS count"}
 	for _, agg := range aggregators {
-		sql := fmt.Sprintf("%s AS %s", agg.aggregator(agg.column), agg.column.name)
+		sql, err := agg.AggregateSql() // each aggregator renders its own select-list entry
+		if err != nil {
+			return "", err
+		}
 		selectList = append(selectList, sql)
 	}
-	return strings.Join(selectList, ", ")
+	return strings.Join(selectList, ", "), nil
 }
 
 func (qb *QueryBuilder) groupByOrderSql(order *model.Order) (string, error) {
@@ -962,23 +957,6 @@ func (qb *QueryBuilder) groupByOrderSql(order *model.Order) (string, error) {
 	return fmt.Sprintf("ORDER BY %s %s", col, order.Direction), nil
 }
 
-func getAggregatorFn(aggregateType AggregateType) (aggregatorFn, error) {
-	switch aggregateType {
-	case Max:
-		return func(col *queryColumn) string {
-			return fmt.Sprintf("MAX(%s.%s)", col.abbrev, col.name)
-		}, nil
-	case Average:
-		return func(col *queryColumn) string {
-			return fmt.Sprintf("AVG(%s.%s)", col.abbrev, col.name)
-		}, nil
-	case Unknown:
-		return nil, errors.New("unknown aggregate type")
-	default:
-		return nil, errors.Errorf("cannot determine aggregate type: %v", aggregateType)
-	}
-}
-
 func (qb *QueryBuilder) getQueryColumn(col string, queryTables map[string]bool) (*queryColumn, error) {
 	table, err := qb.highestPrecedenceTableForColumn(col, queryTables)
 	if err != nil {
diff --git a/internal/lookoutv2/repository/common_test.go b/internal/lookoutv2/repository/querybuilder_test.go
similarity index 88%
rename from internal/lookoutv2/repository/common_test.go
rename to internal/lookoutv2/repository/querybuilder_test.go
index 3fe2dd708c5..aa15d3b82c0 100644
--- a/internal/lookoutv2/repository/common_test.go
+++ b/internal/lookoutv2/repository/querybuilder_test.go
@@ -7,6 +7,7 @@ import (
 
 	"github.com/stretchr/testify/assert"
 
+	"github.com/armadaproject/armada/internal/common/database/lookout"
 	"github.com/armadaproject/armada/internal/common/util"
 	"github.com/armadaproject/armada/internal/lookoutv2/model"
 )
@@ -446,6 +447,66 @@ func TestQueryBuilder_GroupByMultipleAggregates(t *testing.T) {
 	assert.Equal(t, []interface{}{"test\\queue", "1234", "abcd", "test\\queue", "5678", "efgh%", "test\\queue", "anon\\\\one%"}, query.Args)
 }
 
+// TestQueryBuilder_GroupByStateAggregates checks that requesting the "state" aggregate together with a
+// state filter produces one SUM(CASE ...) column for each state named in that filter.
+func TestQueryBuilder_GroupByStateAggregates(t *testing.T) {
+	stateFilter := &model.Filter{
+		Field: "state",
+		Match: model.MatchAnyOf,
+		Value: []string{
+			string(lookout.JobQueued),
+			string(lookout.JobLeased),
+			string(lookout.JobPending),
+			string(lookout.JobRunning),
+		},
+	}
+	query, err := NewQueryBuilder(NewTables()).GroupBy(
+		append(testFilters, stateFilter),
+		&model.Order{
+			Direction: "DESC",
+			Field:     "lastTransitionTime",
+		},
+		&model.GroupedField{
+			Field: "jobSet",
+		},
+		[]string{
+			"lastTransitionTime",
+			"submitted",
+			"state",
+		},
+		20,
+		100,
+	)
+	assert.NoError(t, err)
+	assert.Equal(t, splitByWhitespace(`
+		SELECT j.jobset,
+			COUNT(*) AS count,
+			AVG(j.last_transition_time_seconds) AS last_transition_time_seconds,
+			MAX(j.submitted) AS submitted,
+			SUM(CASE WHEN j.state = 1 THEN 1 ELSE 0 END) AS state_QUEUED,
+			SUM(CASE WHEN j.state = 8 THEN 1 ELSE 0 END) AS state_LEASED,
+			SUM(CASE WHEN j.state = 2 THEN 1 ELSE 0 END) AS state_PENDING,
+			SUM(CASE WHEN j.state = 3 THEN 1 ELSE 0 END) AS state_RUNNING
+		FROM job AS j
+		INNER JOIN (
+			SELECT job_id
+			FROM user_annotation_lookup
+			WHERE queue = $1 AND key = $2 AND value = $3
+		) AS ual0 ON j.job_id = ual0.job_id
+		INNER JOIN (
+			SELECT job_id
+			FROM user_annotation_lookup
+			WHERE queue = $4 AND key = $5 AND value LIKE $6
+		) AS ual1 ON j.job_id = ual1.job_id
+		WHERE j.queue = $7 AND j.owner LIKE $8 AND j.state IN ($9, $10, $11, $12)
+		GROUP BY j.jobset
+		ORDER BY last_transition_time_seconds DESC
+		LIMIT 100 OFFSET 20
+	`),
+		splitByWhitespace(query.Sql))
+	assert.Equal(t, []interface{}{"test\\queue", "1234", "abcd", "test\\queue", "5678", "efgh%", "test\\queue", "anon\\\\one%", 1, 8, 2, 3}, query.Args) // state filter values are bound as ordinals: QUEUED=1, LEASED=8, PENDING=2, RUNNING=3
+}
+
 func TestQueryBuilder_GroupByAnnotationMultipleAggregates(t *testing.T) {
 	query, err := NewQueryBuilder(NewTables()).GroupBy(
 		testFilters,
diff --git a/internal/lookoutv2/repository/tables.go b/internal/lookoutv2/repository/tables.go
index 4633620ec31..779f53fc854 100644
--- a/internal/lookoutv2/repository/tables.go
+++ b/internal/lookoutv2/repository/tables.go
@@ -41,9 +41,10 @@ const (
 type AggregateType int
 
 const (
-	Unknown AggregateType = -1
-	Max                   = 0
-	Average               = 1
+	Unknown     AggregateType = -1
+	Max         AggregateType = 0 // typed explicitly: with the type omitted these would be untyped int constants
+	Average     AggregateType = 1
+	StateCounts AggregateType = 2 // aggregate registered for the state column in NewTables
 )
 
 type LookoutTables struct {
@@ -134,6 +135,7 @@ func NewTables() *LookoutTables {
 		groupAggregates: map[string]AggregateType{
 			submittedCol:          Max,
 			lastTransitionTimeCol: Average,
+			stateCol:              StateCounts,
 		},
 	}
 }
diff --git a/internal/lookoutv2/repository/util.go b/internal/lookoutv2/repository/util.go
index 2b7d8820e38..00af6da2b06 100644
--- a/internal/lookoutv2/repository/util.go
+++ b/internal/lookoutv2/repository/util.go
@@ -166,6 +166,30 @@ func (js *JobSimulator) Submit(queue, jobSet, owner string, timestamp time.Time,
 	return js
 }
 
+func (js *JobSimulator) Lease(runId string, timestamp time.Time) *JobSimulator { // records a JobRunLeased event and marks the simulated job as leased
+	ts := timestampOrNow(timestamp)
+	leasedEvent := &armadaevents.EventSequence_Event{
+		Created: &ts,
+		Event: &armadaevents.EventSequence_Event_JobRunLeased{
+			JobRunLeased: &armadaevents.JobRunLeased{
+				RunId: armadaevents.ProtoUuidFromUuid(uuid.MustParse(runId)), // MustParse panics if runId is not a valid UUID string
+				JobId: js.jobId,
+			},
+		},
+	}
+	js.events = append(js.events, leasedEvent)
+
+	js.job.LastActiveRunId = &runId
+	js.job.LastTransitionTime = ts
+	js.job.State = string(lookout.JobLeased)
+	updateRun(js.job, &runPatch{
+		runId:       runId,
+		jobRunState: pointer.String(string(lookout.JobRunLeased)),
+		pending:     &ts, // NOTE(review): the lease timestamp is stored in the run's pending field; confirm this is intended
+	})
+	return js
+}
+
 func (js *JobSimulator) Pending(runId string, cluster string, timestamp time.Time)
*JobSimulator {
 	ts := timestampOrNow(timestamp)
 	assignedEvent := &armadaevents.EventSequence_Event{
@@ -417,6 +441,36 @@ func (js *JobSimulator) Failed(node string, exitCode int32, message string, time
 	return js
 }
 
+// Preempted appends a JobRunPreempted event for the simulated job and sets the job's state to
+// lookout.JobPreempted with the given timestamp as its last transition time. The preempting job ID
+// and both run IDs placed in the event are freshly generated values.
+func (js *JobSimulator) Preempted(timestamp time.Time) *JobSimulator {
+	ts := timestampOrNow(timestamp)
+	// Generate the preempting job's ID once so that a conversion failure logs the ID that actually failed.
+	preemptiveJobId := util.NewULID()
+	jobIdProto, err := armadaevents.ProtoUuidFromUlidString(preemptiveJobId)
+	if err != nil {
+		log.WithError(err).Errorf("Could not convert job ID to UUID: %s", preemptiveJobId)
+	}
+
+	preempted := &armadaevents.EventSequence_Event{
+		Created: &ts,
+		Event: &armadaevents.EventSequence_Event_JobRunPreempted{
+			JobRunPreempted: &armadaevents.JobRunPreempted{
+				PreemptedJobId:  js.jobId,
+				PreemptiveJobId: jobIdProto,
+				PreemptedRunId:  armadaevents.ProtoUuidFromUuid(uuid.New()), // random run ID, not one created by an earlier simulator step
+				PreemptiveRunId: armadaevents.ProtoUuidFromUuid(uuid.New()),
+			},
+		},
+	}
+	js.events = append(js.events, preempted)
+
+	js.job.LastTransitionTime = ts
+	js.job.State = string(lookout.JobPreempted)
+	return js
+}
+
 func (js *JobSimulator) RunTerminated(runId string, cluster string, node string, message string, timestamp time.Time) *JobSimulator {
 	ts := timestampOrNow(timestamp)
 	terminated := &armadaevents.EventSequence_Event{
diff --git a/internal/lookoutv2/swagger.yaml b/internal/lookoutv2/swagger.yaml
index 1b81fffe86d..6a18a4dc1e9 100644
--- a/internal/lookoutv2/swagger.yaml
+++ b/internal/lookoutv2/swagger.yaml
@@ -178,7 +178,7 @@ definitions:
       aggregates:
         type: object
         additionalProperties:
-          type: string
+          type: object
         x-nullable: false
   filter:
     type: object