diff --git a/cmd/armada/main.go b/cmd/armada/main.go index 688fd78c029..a5eb8751d63 100644 --- a/cmd/armada/main.go +++ b/cmd/armada/main.go @@ -1,7 +1,6 @@ package main import ( - "context" "fmt" "net/http" _ "net/http/pprof" @@ -13,11 +12,11 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/pflag" "github.com/spf13/viper" - "golang.org/x/sync/errgroup" "github.com/armadaproject/armada/internal/armada" "github.com/armadaproject/armada/internal/armada/configuration" "github.com/armadaproject/armada/internal/common" + "github.com/armadaproject/armada/internal/common/armadacontext" gateway "github.com/armadaproject/armada/internal/common/grpc" "github.com/armadaproject/armada/internal/common/health" "github.com/armadaproject/armada/internal/common/logging" @@ -67,7 +66,7 @@ func main() { } // Run services within an errgroup to propagate errors between services. - g, ctx := errgroup.WithContext(context.Background()) + g, ctx := armadacontext.ErrGroup(armadacontext.Background()) // Cancel the errgroup context on SIGINT and SIGTERM, // which shuts everything down gracefully. diff --git a/cmd/eventsprinter/logic/logic.go b/cmd/eventsprinter/logic/logic.go index 34de61b4d61..b7a9dab8ea7 100644 --- a/cmd/eventsprinter/logic/logic.go +++ b/cmd/eventsprinter/logic/logic.go @@ -1,7 +1,6 @@ package logic import ( - "context" "fmt" "time" @@ -9,6 +8,7 @@ import ( "github.com/gogo/protobuf/proto" v1 "k8s.io/api/core/v1" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/pkg/armadaevents" ) @@ -18,7 +18,7 @@ func PrintEvents(url, topic, subscription string, verbose bool) error { fmt.Println("URL:", url) fmt.Println("Topic:", topic) fmt.Println("Subscription", subscription) - return withSetup(url, topic, subscription, func(ctx context.Context, producer pulsar.Producer, consumer pulsar.Consumer) error { + return withSetup(url, topic, subscription, func(ctx *armadacontext.Context, producer pulsar.Producer, consumer pulsar.Consumer) error { // Number of active jobs. numJobs := 0 @@ -199,7 +199,7 @@ func stripPodSpec(spec *v1.PodSpec) *v1.PodSpec { } // Run action with an Armada submit client and a Pulsar producer and consumer. 
-func withSetup(url, topic, subscription string, action func(ctx context.Context, producer pulsar.Producer, consumer pulsar.Consumer) error) error { +func withSetup(url, topic, subscription string, action func(ctx *armadacontext.Context, producer pulsar.Producer, consumer pulsar.Consumer) error) error { pulsarClient, err := pulsar.NewClient(pulsar.ClientOptions{ URL: url, }) @@ -225,5 +225,5 @@ func withSetup(url, topic, subscription string, action func(ctx context.Context, } defer consumer.Close() - return action(context.Background(), producer, consumer) + return action(armadacontext.Background(), producer, consumer) } diff --git a/cmd/executor/main.go b/cmd/executor/main.go index ed8444fbdb4..ac6374a186c 100644 --- a/cmd/executor/main.go +++ b/cmd/executor/main.go @@ -1,7 +1,6 @@ package main import ( - "context" "net/http" "os" "os/signal" @@ -13,6 +12,7 @@ import ( "github.com/spf13/viper" "github.com/armadaproject/armada/internal/common" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/health" "github.com/armadaproject/armada/internal/executor" "github.com/armadaproject/armada/internal/executor/configuration" @@ -55,7 +55,7 @@ func main() { ) defer shutdownMetricServer() - shutdown, wg := executor.StartUp(context.Background(), logrus.NewEntry(logrus.New()), config) + shutdown, wg := executor.StartUp(armadacontext.Background(), logrus.NewEntry(logrus.New()), config) go func() { <-shutdownChannel shutdown() diff --git a/cmd/lookoutv2/main.go b/cmd/lookoutv2/main.go index 3ba4a865e4d..a2d5f6be90e 100644 --- a/cmd/lookoutv2/main.go +++ b/cmd/lookoutv2/main.go @@ -1,7 +1,6 @@ package main import ( - "context" "os" "os/signal" "syscall" @@ -12,6 +11,7 @@ import ( "k8s.io/apimachinery/pkg/util/clock" "github.com/armadaproject/armada/internal/common" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database" "github.com/armadaproject/armada/internal/lookoutv2" "github.com/armadaproject/armada/internal/lookoutv2/configuration" @@ -36,9 +36,9 @@ func init() { pflag.Parse() } -func makeContext() (context.Context, func()) { - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) +func makeContext() (*armadacontext.Context, func()) { + ctx := armadacontext.Background() + ctx, cancel := armadacontext.WithCancel(ctx) c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt, syscall.SIGTERM) @@ -57,7 +57,7 @@ func makeContext() (context.Context, func()) { } } -func migrate(ctx context.Context, config configuration.LookoutV2Configuration) { +func migrate(ctx *armadacontext.Context, config configuration.LookoutV2Configuration) { db, err := database.OpenPgxPool(config.Postgres) if err != nil { panic(err) @@ -74,7 +74,7 @@ func migrate(ctx context.Context, config configuration.LookoutV2Configuration) { } } -func prune(ctx context.Context, config configuration.LookoutV2Configuration) { +func prune(ctx *armadacontext.Context, config configuration.LookoutV2Configuration) { db, err := database.OpenPgxConn(config.Postgres) if err != nil { panic(err) @@ -92,7 +92,7 @@ func prune(ctx context.Context, config configuration.LookoutV2Configuration) { log.Infof("expireAfter: %v, batchSize: %v, timeout: %v", config.PrunerConfig.ExpireAfter, config.PrunerConfig.BatchSize, config.PrunerConfig.Timeout) - ctxTimeout, cancel := context.WithTimeout(ctx, config.PrunerConfig.Timeout) + ctxTimeout, cancel := armadacontext.WithTimeout(ctx, config.PrunerConfig.Timeout) defer 
cancel() err = pruner.PruneDb(ctxTimeout, db, config.PrunerConfig.ExpireAfter, config.PrunerConfig.BatchSize, clock.RealClock{}) if err != nil { diff --git a/cmd/scheduler/cmd/migrate_database.go b/cmd/scheduler/cmd/migrate_database.go index 1564bffb9fd..22d6dc12dc3 100644 --- a/cmd/scheduler/cmd/migrate_database.go +++ b/cmd/scheduler/cmd/migrate_database.go @@ -1,13 +1,13 @@ package cmd import ( - "context" "time" "github.com/pkg/errors" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database" schedulerdb "github.com/armadaproject/armada/internal/scheduler/database" ) @@ -43,7 +43,7 @@ func migrateDatabase(cmd *cobra.Command, _ []string) error { return errors.WithMessagef(err, "Failed to connect to database") } - ctx, cancel := context.WithTimeout(context.Background(), timeout) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), timeout) defer cancel() return schedulerdb.Migrate(ctx, db) } diff --git a/internal/armada/repository/event.go b/internal/armada/repository/event.go index 2e05ba377c6..9df6d7a1a05 100644 --- a/internal/armada/repository/event.go +++ b/internal/armada/repository/event.go @@ -14,6 +14,7 @@ import ( "github.com/armadaproject/armada/internal/armada/repository/apimessages" "github.com/armadaproject/armada/internal/armada/repository/sequence" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/pkg/api" "github.com/armadaproject/armada/pkg/armadaevents" @@ -48,7 +49,7 @@ func NewEventRepository(db redis.UniversalClient) *RedisEventRepository { NumTestsPerEvictionRun: 10, } - decompressorPool := pool.NewObjectPool(context.Background(), pool.NewPooledObjectFactorySimple( + decompressorPool := pool.NewObjectPool(armadacontext.Background(), pool.NewPooledObjectFactorySimple( func(context.Context) (interface{}, error) { return compress.NewZlibDecompressor(), nil }), &poolConfig) @@ -134,16 +135,16 @@ func (repo *RedisEventRepository) GetLastMessageId(queue, jobSetId string) (stri func (repo *RedisEventRepository) extractEvents(msg redis.XMessage, queue, jobSetId string) ([]*api.EventMessage, error) { data := msg.Values[dataKey] bytes := []byte(data.(string)) - decompressor, err := repo.decompressorPool.BorrowObject(context.Background()) + decompressor, err := repo.decompressorPool.BorrowObject(armadacontext.Background()) if err != nil { return nil, errors.WithStack(err) } - defer func(decompressorPool *pool.ObjectPool, ctx context.Context, object interface{}) { + defer func(decompressorPool *pool.ObjectPool, ctx *armadacontext.Context, object interface{}) { err := decompressorPool.ReturnObject(ctx, object) if err != nil { log.WithError(err).Errorf("Error returning decompressor to pool") } - }(repo.decompressorPool, context.Background(), decompressor) + }(repo.decompressorPool, armadacontext.Background(), decompressor) decompressedData, err := decompressor.(compress.Decompressor).Decompress(bytes) if err != nil { return nil, errors.WithStack(err) diff --git a/internal/armada/repository/event_store.go b/internal/armada/repository/event_store.go index 7241cba02ef..248a405b6a4 100644 --- a/internal/armada/repository/event_store.go +++ b/internal/armada/repository/event_store.go @@ -1,10 +1,9 @@ package repository import ( - "context" - "github.com/apache/pulsar-client-go/pulsar" + 
"github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/eventutil" "github.com/armadaproject/armada/internal/common/pulsarutils" "github.com/armadaproject/armada/internal/common/schedulers" @@ -12,14 +11,14 @@ import ( ) type EventStore interface { - ReportEvents(context.Context, []*api.EventMessage) error + ReportEvents(*armadacontext.Context, []*api.EventMessage) error } type TestEventStore struct { ReceivedEvents []*api.EventMessage } -func (es *TestEventStore) ReportEvents(_ context.Context, message []*api.EventMessage) error { +func (es *TestEventStore) ReportEvents(_ *armadacontext.Context, message []*api.EventMessage) error { es.ReceivedEvents = append(es.ReceivedEvents, message...) return nil } @@ -35,7 +34,7 @@ func NewEventStore(producer pulsar.Producer, maxAllowedMessageSize uint) *Stream } } -func (n *StreamEventStore) ReportEvents(ctx context.Context, apiEvents []*api.EventMessage) error { +func (n *StreamEventStore) ReportEvents(ctx *armadacontext.Context, apiEvents []*api.EventMessage) error { if len(apiEvents) == 0 { return nil } diff --git a/internal/armada/scheduling/lease_manager.go b/internal/armada/scheduling/lease_manager.go index 9b34786af9c..6e1c6385f9f 100644 --- a/internal/armada/scheduling/lease_manager.go +++ b/internal/armada/scheduling/lease_manager.go @@ -1,12 +1,12 @@ package scheduling import ( - "context" "time" log "github.com/sirupsen/logrus" "github.com/armadaproject/armada/internal/armada/repository" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/pkg/api" ) @@ -55,7 +55,7 @@ func (l *LeaseManager) ExpireLeases() { if e != nil { log.Error(e) } else { - e := l.eventStore.ReportEvents(context.Background(), []*api.EventMessage{event}) + e := l.eventStore.ReportEvents(armadacontext.Background(), []*api.EventMessage{event}) if e != nil { log.Error(e) } diff --git a/internal/armada/server.go b/internal/armada/server.go index e60567583bc..7f77b26b0d9 100644 --- a/internal/armada/server.go +++ b/internal/armada/server.go @@ -1,7 +1,6 @@ package armada import ( - "context" "fmt" "net" "time" @@ -13,7 +12,6 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/pkg/errors" log "github.com/sirupsen/logrus" - "golang.org/x/sync/errgroup" "google.golang.org/grpc" "github.com/armadaproject/armada/internal/armada/cache" @@ -22,6 +20,7 @@ import ( "github.com/armadaproject/armada/internal/armada/repository" "github.com/armadaproject/armada/internal/armada/scheduling" "github.com/armadaproject/armada/internal/armada/server" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/auth" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/internal/common/database" @@ -39,7 +38,7 @@ import ( "github.com/armadaproject/armada/pkg/client" ) -func Serve(ctx context.Context, config *configuration.ArmadaConfig, healthChecks *health.MultiChecker) error { +func Serve(ctx *armadacontext.Context, config *configuration.ArmadaConfig, healthChecks *health.MultiChecker) error { log.Info("Armada server starting") log.Infof("Armada priority classes: %v", config.Scheduling.Preemption.PriorityClasses) log.Infof("Default priority class: %s", config.Scheduling.Preemption.DefaultPriorityClass) @@ -51,9 +50,9 @@ func Serve(ctx context.Context, config *configuration.ArmadaConfig, healthChecks // Run all services within an errgroup to propagate errors between services. 
// Defer cancelling the parent context to ensure the errgroup is cancelled on return. - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := armadacontext.WithCancel(ctx) defer cancel() - g, ctx := errgroup.WithContext(ctx) + g, ctx := armadacontext.ErrGroup(ctx) // List of services to run concurrently. // Because we want to start services only once all input validation has been completed, diff --git a/internal/armada/server/authorization.go b/internal/armada/server/authorization.go index 1d11253d3c7..434771afcbf 100644 --- a/internal/armada/server/authorization.go +++ b/internal/armada/server/authorization.go @@ -1,10 +1,10 @@ package server import ( - "context" "fmt" "strings" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/internal/common/auth/permission" "github.com/armadaproject/armada/pkg/client/queue" @@ -60,7 +60,7 @@ func MergePermissionErrors(errs ...*ErrUnauthorized) *ErrUnauthorized { // permissions required to perform some action. The error returned is of type ErrUnauthorized. // After recovering the error (using errors.As), the caller can obtain the name of the user and the // requested permission programatically via this error type. -func checkPermission(p authorization.PermissionChecker, ctx context.Context, permission permission.Permission) error { +func checkPermission(p authorization.PermissionChecker, ctx *armadacontext.Context, permission permission.Permission) error { if !p.UserHasPermission(ctx, permission) { return &ErrUnauthorized{ Principal: authorization.GetPrincipal(ctx), @@ -74,7 +74,7 @@ func checkPermission(p authorization.PermissionChecker, ctx context.Context, per func checkQueuePermission( p authorization.PermissionChecker, - ctx context.Context, + ctx *armadacontext.Context, q queue.Queue, globalPermission permission.Permission, verb queue.PermissionVerb, diff --git a/internal/armada/server/event.go b/internal/armada/server/event.go index 484f1a3a9f9..14ea0d58e18 100644 --- a/internal/armada/server/event.go +++ b/internal/armada/server/event.go @@ -13,6 +13,7 @@ import ( "github.com/armadaproject/armada/internal/armada/permissions" "github.com/armadaproject/armada/internal/armada/repository" "github.com/armadaproject/armada/internal/armada/repository/sequence" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/pkg/api" "github.com/armadaproject/armada/pkg/client/queue" @@ -42,7 +43,8 @@ func NewEventServer( } } -func (s *EventServer) Report(ctx context.Context, message *api.EventMessage) (*types.Empty, error) { +func (s *EventServer) Report(grpcCtx context.Context, message *api.EventMessage) (*types.Empty, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) if err := checkPermission(s.permissions, ctx, permissions.ExecuteJobs); err != nil { return nil, status.Errorf(codes.PermissionDenied, "[Report] error: %s", err) } @@ -50,7 +52,8 @@ func (s *EventServer) Report(ctx context.Context, message *api.EventMessage) (*t return &types.Empty{}, s.eventStore.ReportEvents(ctx, []*api.EventMessage{message}) } -func (s *EventServer) ReportMultiple(ctx context.Context, message *api.EventList) (*types.Empty, error) { +func (s *EventServer) ReportMultiple(grpcCtx context.Context, message *api.EventList) (*types.Empty, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) if err := checkPermission(s.permissions, ctx, 
permissions.ExecuteJobs); err != nil { return nil, status.Errorf(codes.PermissionDenied, "[ReportMultiple] error: %s", err) } @@ -116,6 +119,7 @@ func (s *EventServer) enrichPreemptedEvent(event *api.EventMessage_Preempted, jo // GetJobSetEvents streams back all events associated with a particular job set. func (s *EventServer) GetJobSetEvents(request *api.JobSetRequest, stream api.Event_GetJobSetEventsServer) error { + ctx := armadacontext.FromGrpcCtx(stream.Context()) q, err := s.queueRepository.GetQueue(request.Queue) var expected *repository.ErrQueueNotFound if errors.As(err, &expected) { @@ -124,7 +128,7 @@ func (s *EventServer) GetJobSetEvents(request *api.JobSetRequest, stream api.Eve return err } - err = validateUserHasWatchPermissions(stream.Context(), s.permissions, q, request.Id) + err = validateUserHasWatchPermissions(ctx, s.permissions, q, request.Id) if err != nil { return status.Errorf(codes.PermissionDenied, "[GetJobSetEvents] %s", err) } @@ -142,7 +146,7 @@ func (s *EventServer) GetJobSetEvents(request *api.JobSetRequest, stream api.Eve return s.serveEventsFromRepository(request, s.eventRepository, stream) } -func (s *EventServer) Health(ctx context.Context, cont_ *types.Empty) (*api.HealthCheckResponse, error) { +func (s *EventServer) Health(_ context.Context, _ *types.Empty) (*api.HealthCheckResponse, error) { return &api.HealthCheckResponse{Status: api.HealthCheckResponse_SERVING}, nil } @@ -222,7 +226,7 @@ func (s *EventServer) serveEventsFromRepository(request *api.JobSetRequest, even } } -func validateUserHasWatchPermissions(ctx context.Context, permsChecker authorization.PermissionChecker, q queue.Queue, jobSetId string) error { +func validateUserHasWatchPermissions(ctx *armadacontext.Context, permsChecker authorization.PermissionChecker, q queue.Queue, jobSetId string) error { err := checkPermission(permsChecker, ctx, permissions.WatchAllEvents) var globalPermErr *ErrUnauthorized if errors.As(err, &globalPermErr) { diff --git a/internal/armada/server/event_test.go b/internal/armada/server/event_test.go index a31f24965dc..18d77478f1c 100644 --- a/internal/armada/server/event_test.go +++ b/internal/armada/server/event_test.go @@ -18,6 +18,7 @@ import ( "github.com/armadaproject/armada/internal/armada/permissions" "github.com/armadaproject/armada/internal/armada/repository" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/internal/common/auth/permission" "github.com/armadaproject/armada/internal/common/compress" @@ -30,7 +31,7 @@ func TestEventServer_Health(t *testing.T) { withEventServer( t, func(s *EventServer) { - health, err := s.Health(context.Background(), &types.Empty{}) + health, err := s.Health(armadacontext.Background(), &types.Empty{}) assert.Equal(t, health.Status, api.HealthCheckResponse_SERVING) require.NoError(t, err) }, @@ -274,7 +275,7 @@ func TestEventServer_GetJobSetEvents_Permissions(t *testing.T) { assert.NoError(t, err) principal := authorization.NewStaticPrincipal("alice", []string{}) - ctx := authorization.WithPrincipal(context.Background(), principal) + ctx := authorization.WithPrincipal(armadacontext.Background(), principal) stream := &eventStreamMock{ctx: ctx} err = s.GetJobSetEvents(&api.JobSetRequest{ @@ -298,7 +299,7 @@ func TestEventServer_GetJobSetEvents_Permissions(t *testing.T) { assert.NoError(t, err) principal := authorization.NewStaticPrincipal("alice", []string{"watch-all-events-group"}) - ctx := 
authorization.WithPrincipal(context.Background(), principal) + ctx := authorization.WithPrincipal(armadacontext.Background(), principal) stream := &eventStreamMock{ctx: ctx} err = s.GetJobSetEvents(&api.JobSetRequest{ @@ -322,7 +323,7 @@ func TestEventServer_GetJobSetEvents_Permissions(t *testing.T) { assert.NoError(t, err) principal := authorization.NewStaticPrincipal("alice", []string{"watch-queue-group"}) - ctx := authorization.WithPrincipal(context.Background(), principal) + ctx := authorization.WithPrincipal(armadacontext.Background(), principal) stream := &eventStreamMock{ctx: ctx} err = s.GetJobSetEvents(&api.JobSetRequest{ @@ -344,7 +345,7 @@ func TestEventServer_GetJobSetEvents_Permissions(t *testing.T) { assert.NoError(t, err) principal := authorization.NewStaticPrincipal("alice", []string{"watch-events-group", "watch-queue-group"}) - ctx := authorization.WithPrincipal(context.Background(), principal) + ctx := authorization.WithPrincipal(armadacontext.Background(), principal) stream := &eventStreamMock{ctx: ctx} err = s.GetJobSetEvents(&api.JobSetRequest{ @@ -426,7 +427,7 @@ func (s *eventStreamMock) Send(m *api.EventStreamMessage) error { func (s *eventStreamMock) Context() context.Context { if s.ctx == nil { - return context.Background() + return armadacontext.Background() } return s.ctx } diff --git a/internal/armada/server/eventsprinter.go b/internal/armada/server/eventsprinter.go index 90bbca97f83..d2ba150d6e4 100644 --- a/internal/armada/server/eventsprinter.go +++ b/internal/armada/server/eventsprinter.go @@ -9,6 +9,7 @@ import ( "github.com/gogo/protobuf/proto" "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/eventutil" "github.com/armadaproject/armada/internal/common/logging" "github.com/armadaproject/armada/internal/common/pulsarutils/pulsarrequestid" @@ -29,7 +30,7 @@ type EventsPrinter struct { } // Run the service that reads from Pulsar and updates Armada until the provided context is cancelled. -func (srv *EventsPrinter) Run(ctx context.Context) error { +func (srv *EventsPrinter) Run(ctx *armadacontext.Context) error { // Get the configured logger, or the standard logger if none is provided. var log *logrus.Entry if srv.Logger != nil { @@ -74,7 +75,7 @@ func (srv *EventsPrinter) Run(ctx context.Context) error { default: // Get a message from Pulsar, which consists of a sequence of events (i.e., state transitions). 
- ctxWithTimeout, cancel := context.WithTimeout(ctx, 10*time.Second) + ctxWithTimeout, cancel := armadacontext.WithTimeout(ctx, 10*time.Second) msg, err := consumer.Receive(ctxWithTimeout) cancel() if errors.Is(err, context.DeadlineExceeded) { // expected @@ -85,7 +86,7 @@ func (srv *EventsPrinter) Run(ctx context.Context) error { break } util.RetryUntilSuccess( - context.Background(), + armadacontext.Background(), func() error { return consumer.Ack(msg) }, func(err error) { logging.WithStacktrace(log, err).Warnf("acking pulsar message failed") diff --git a/internal/armada/server/lease.go b/internal/armada/server/lease.go index 7d1d7c2abec..9a776d0e15f 100644 --- a/internal/armada/server/lease.go +++ b/internal/armada/server/lease.go @@ -10,11 +10,9 @@ import ( "github.com/apache/pulsar-client-go/pulsar" "github.com/gogo/protobuf/types" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/hashicorp/go-multierror" pool "github.com/jolestar/go-commons-pool" "github.com/pkg/errors" - "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" "golang.org/x/sync/errgroup" @@ -27,6 +25,7 @@ import ( "github.com/armadaproject/armada/internal/armada/permissions" "github.com/armadaproject/armada/internal/armada/repository" "github.com/armadaproject/armada/internal/armada/scheduling" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/armadaerrors" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/internal/common/compress" @@ -97,6 +96,7 @@ func NewAggregatedQueueServer( TimeBetweenEvictionRuns: 0, NumTestsPerEvictionRun: 10, } + decompressorPool := pool.NewObjectPool(context.Background(), pool.NewPooledObjectFactorySimple( func(context.Context) (interface{}, error) { return compress.NewZlibDecompressor(), nil @@ -128,7 +128,7 @@ func NewAggregatedQueueServer( // // This function should be used instead of the LeaseJobs function in most cases. func (q *AggregatedQueueServer) StreamingLeaseJobs(stream api.AggregatedQueue_StreamingLeaseJobsServer) error { - if err := checkPermission(q.permissions, stream.Context(), permissions.ExecuteJobs); err != nil { + if err := checkPermission(q.permissions, armadacontext.FromGrpcCtx(stream.Context()), permissions.ExecuteJobs); err != nil { return err } @@ -151,7 +151,7 @@ func (q *AggregatedQueueServer) StreamingLeaseJobs(stream api.AggregatedQueue_St } // Get jobs to be leased. - jobs, err := q.getJobs(stream.Context(), req) + jobs, err := q.getJobs(armadacontext.FromGrpcCtx(stream.Context()), req) if err != nil { return err } @@ -262,14 +262,12 @@ func (repo *SchedulerJobRepositoryAdapter) GetExistingJobsByIds(ids []string) ([ return rv, nil } -func (q *AggregatedQueueServer) getJobs(ctx context.Context, req *api.StreamingLeaseRequest) ([]*api.Job, error) { - log := ctxlogrus.Extract(ctx) - log = log.WithFields(logrus.Fields{ - "function": "getJobs", - "cluster": req.ClusterId, - "pool": req.Pool, - }) - ctx = ctxlogrus.ToContext(ctx, log) +func (q *AggregatedQueueServer) getJobs(ctx *armadacontext.Context, req *api.StreamingLeaseRequest) ([]*api.Job, error) { + ctx = armadacontext. + WithLogFields(ctx, map[string]interface{}{ + "cluster": req.ClusterId, + "pool": req.Pool, + }) // Get the total capacity available across all clusters. 
usageReports, err := q.usageRepository.GetClusterUsageReports() @@ -346,7 +344,7 @@ func (q *AggregatedQueueServer) getJobs(ctx context.Context, req *api.StreamingL lastSeen, ) if err != nil { - logging.WithStacktrace(log, err).Warnf( + logging.WithStacktrace(ctx.Log, err).Warnf( "skipping node %s from executor %s", nodeInfo.GetName(), req.GetClusterId(), ) continue @@ -474,7 +472,7 @@ func (q *AggregatedQueueServer) getJobs(ctx context.Context, req *api.StreamingL // Give Schedule() a 3 second shorter deadline than ctx to give it a chance to finish up before ctx deadline. if deadline, ok := ctx.Deadline(); ok { var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(ctx, deadline.Add(-3*time.Second)) + ctx, cancel = armadacontext.WithDeadline(ctx, deadline.Add(-3*time.Second)) defer cancel() } @@ -558,12 +556,7 @@ func (q *AggregatedQueueServer) getJobs(ctx context.Context, req *api.StreamingL "starting scheduling with total resources %s", schedulerobjects.ResourceList{Resources: totalCapacity}.CompactString(), ) - result, err := sch.Schedule( - ctxlogrus.ToContext( - ctx, - logrus.NewEntry(logrus.New()), - ), - ) + result, err := sch.Schedule(ctx) if err != nil { return nil, err } @@ -573,7 +566,7 @@ func (q *AggregatedQueueServer) getJobs(ctx context.Context, req *api.StreamingL if q.SchedulingContextRepository != nil { sctx.ClearJobSpecs() if err := q.SchedulingContextRepository.AddSchedulingContext(sctx); err != nil { - logging.WithStacktrace(log, err).Error("failed to store scheduling context") + logging.WithStacktrace(ctx.Log, err).Error("failed to store scheduling context") } } @@ -648,7 +641,7 @@ func (q *AggregatedQueueServer) getJobs(ctx context.Context, req *api.StreamingL jobIdsToDelete := util.Map(jobsToDelete, func(job *api.Job) string { return job.Id }) log.Infof("deleting preempted jobs: %v", jobIdsToDelete) if deletionResult, err := q.jobRepository.DeleteJobs(jobsToDelete); err != nil { - logging.WithStacktrace(log, err).Error("failed to delete preempted jobs from Redis") + logging.WithStacktrace(ctx.Log, err).Error("failed to delete preempted jobs from Redis") } else { deleteErrorByJobId := armadamaps.MapKeys(deletionResult, func(job *api.Job) string { return job.Id }) for jobId := range preemptedApiJobsById { @@ -711,7 +704,7 @@ func (q *AggregatedQueueServer) getJobs(ctx context.Context, req *api.StreamingL } } if err := q.usageRepository.UpdateClusterQueueResourceUsage(req.ClusterId, currentExecutorReport); err != nil { - logging.WithStacktrace(log, err).Errorf("failed to update cluster usage") + logging.WithStacktrace(ctx.Log, err).Errorf("failed to update cluster usage") } allocatedByQueueAndPriorityClassForPool = q.aggregateAllocationAcrossExecutor(reportsByExecutor, req.Pool) @@ -735,7 +728,7 @@ func (q *AggregatedQueueServer) getJobs(ctx context.Context, req *api.StreamingL } node, err := nodeDb.GetNode(nodeId) if err != nil { - logging.WithStacktrace(log, err).Warnf("failed to set node id selector on job %s: node with id %s not found", apiJob.Id, nodeId) + logging.WithStacktrace(ctx.Log, err).Warnf("failed to set node id selector on job %s: node with id %s not found", apiJob.Id, nodeId) continue } v := node.Labels[q.schedulingConfig.Preemption.NodeIdLabel] @@ -771,7 +764,7 @@ func (q *AggregatedQueueServer) getJobs(ctx context.Context, req *api.StreamingL } node, err := nodeDb.GetNode(nodeId) if err != nil { - logging.WithStacktrace(log, err).Warnf("failed to set node name on job %s: node with id %s not found", apiJob.Id, nodeId) + 
logging.WithStacktrace(ctx.Log, err).Warnf("failed to set node name on job %s: node with id %s not found", apiJob.Id, nodeId) continue } podSpec.NodeName = node.Name @@ -880,22 +873,23 @@ func (q *AggregatedQueueServer) decompressJobOwnershipGroups(jobs []*api.Job) er } func (q *AggregatedQueueServer) decompressOwnershipGroups(compressedOwnershipGroups []byte) ([]string, error) { - decompressor, err := q.decompressorPool.BorrowObject(context.Background()) + decompressor, err := q.decompressorPool.BorrowObject(armadacontext.Background()) if err != nil { return nil, fmt.Errorf("failed to borrow decompressior because %s", err) } - defer func(decompressorPool *pool.ObjectPool, ctx context.Context, object interface{}) { + defer func(decompressorPool *pool.ObjectPool, ctx *armadacontext.Context, object interface{}) { err := decompressorPool.ReturnObject(ctx, object) if err != nil { log.WithError(err).Errorf("Error returning decompressorPool to pool") } - }(q.decompressorPool, context.Background(), decompressor) + }(q.decompressorPool, armadacontext.Background(), decompressor) return compress.DecompressStringArray(compressedOwnershipGroups, decompressor.(compress.Decompressor)) } -func (q *AggregatedQueueServer) RenewLease(ctx context.Context, request *api.RenewLeaseRequest) (*api.IdList, error) { +func (q *AggregatedQueueServer) RenewLease(grpcCtx context.Context, request *api.RenewLeaseRequest) (*api.IdList, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) if err := checkPermission(q.permissions, ctx, permissions.ExecuteJobs); err != nil { return nil, status.Errorf(codes.PermissionDenied, err.Error()) } @@ -903,7 +897,8 @@ func (q *AggregatedQueueServer) RenewLease(ctx context.Context, request *api.Ren return &api.IdList{Ids: renewed}, e } -func (q *AggregatedQueueServer) ReturnLease(ctx context.Context, request *api.ReturnLeaseRequest) (*types.Empty, error) { +func (q *AggregatedQueueServer) ReturnLease(grpcCtx context.Context, request *api.ReturnLeaseRequest) (*types.Empty, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) if err := checkPermission(q.permissions, ctx, permissions.ExecuteJobs); err != nil { return nil, status.Errorf(codes.PermissionDenied, err.Error()) } @@ -1002,7 +997,8 @@ func (q *AggregatedQueueServer) addAvoidNodeAffinity( return res[0].Error } -func (q *AggregatedQueueServer) ReportDone(ctx context.Context, idList *api.IdList) (*api.IdList, error) { +func (q *AggregatedQueueServer) ReportDone(grpcCtx context.Context, idList *api.IdList) (*api.IdList, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) if err := checkPermission(q.permissions, ctx, permissions.ExecuteJobs); err != nil { return nil, status.Errorf(codes.PermissionDenied, "[ReportDone] error: %s", err) } @@ -1027,7 +1023,7 @@ func (q *AggregatedQueueServer) ReportDone(ctx context.Context, idList *api.IdLi return &api.IdList{Ids: cleanedIds}, returnedError } -func (q *AggregatedQueueServer) reportLeaseReturned(ctx context.Context, leaseReturnRequest *api.ReturnLeaseRequest) error { +func (q *AggregatedQueueServer) reportLeaseReturned(ctx *armadacontext.Context, leaseReturnRequest *api.ReturnLeaseRequest) error { job, err := q.getJobById(leaseReturnRequest.JobId) if err != nil { return err diff --git a/internal/armada/server/lease_test.go b/internal/armada/server/lease_test.go index 7f3f8470491..554282c546a 100644 --- a/internal/armada/server/lease_test.go +++ b/internal/armada/server/lease_test.go @@ -1,7 +1,6 @@ package server import ( - "context" "fmt" "testing" "time" @@ -10,6 +9,7 @@ import ( 
"github.com/armadaproject/armada/internal/armada/configuration" "github.com/armadaproject/armada/internal/armada/repository" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" "github.com/armadaproject/armada/pkg/api" "github.com/armadaproject/armada/pkg/client/queue" @@ -25,7 +25,7 @@ func TestAggregatedQueueServer_ReturnLeaseCallsRepositoryMethod(t *testing.T) { _, addJobsErr := mockJobRepository.AddJobs([]*api.Job{job}) assert.Nil(t, addJobsErr) - _, err := aggregatedQueueClient.ReturnLease(context.TODO(), &api.ReturnLeaseRequest{ + _, err := aggregatedQueueClient.ReturnLease(armadacontext.TODO(), &api.ReturnLeaseRequest{ ClusterId: clusterId, JobId: jobId, }) @@ -54,7 +54,7 @@ func TestAggregatedQueueServer_ReturnLeaseCallsSendsJobLeaseReturnedEvent(t *tes _, addJobsErr := mockJobRepository.AddJobs([]*api.Job{job}) assert.Nil(t, addJobsErr) - _, err := aggregatedQueueClient.ReturnLease(context.TODO(), &api.ReturnLeaseRequest{ + _, err := aggregatedQueueClient.ReturnLease(armadacontext.TODO(), &api.ReturnLeaseRequest{ ClusterId: clusterId, JobId: jobId, Reason: reason, @@ -84,7 +84,7 @@ func TestAggregatedQueueServer_ReturningLeaseMoreThanMaxRetriesDeletesJob(t *tes assert.Nil(t, addJobsErr) for i := 0; i < maxRetries; i++ { - _, err := aggregatedQueueClient.ReturnLease(context.TODO(), &api.ReturnLeaseRequest{ + _, err := aggregatedQueueClient.ReturnLease(armadacontext.TODO(), &api.ReturnLeaseRequest{ ClusterId: clusterId, JobId: jobId, JobRunAttempted: true, @@ -96,7 +96,7 @@ func TestAggregatedQueueServer_ReturningLeaseMoreThanMaxRetriesDeletesJob(t *tes assert.Equal(t, jobId, mockJobRepository.returnLeaseArg2) } - _, err := aggregatedQueueClient.ReturnLease(context.TODO(), &api.ReturnLeaseRequest{ + _, err := aggregatedQueueClient.ReturnLease(armadacontext.TODO(), &api.ReturnLeaseRequest{ ClusterId: clusterId, JobId: jobId, }) @@ -125,7 +125,7 @@ func TestAggregatedQueueServer_ReturningLeaseMoreThanMaxRetriesSendsJobFailedEve assert.Nil(t, addJobsErr) for i := 0; i < maxRetries; i++ { - _, err := aggregatedQueueClient.ReturnLease(context.TODO(), &api.ReturnLeaseRequest{ + _, err := aggregatedQueueClient.ReturnLease(armadacontext.TODO(), &api.ReturnLeaseRequest{ ClusterId: clusterId, JobId: jobId, JobRunAttempted: true, @@ -136,7 +136,7 @@ func TestAggregatedQueueServer_ReturningLeaseMoreThanMaxRetriesSendsJobFailedEve fakeEventStore.events = []*api.EventMessage{} } - _, err := aggregatedQueueClient.ReturnLease(context.TODO(), &api.ReturnLeaseRequest{ + _, err := aggregatedQueueClient.ReturnLease(armadacontext.TODO(), &api.ReturnLeaseRequest{ ClusterId: clusterId, JobId: jobId, }) @@ -169,7 +169,7 @@ func TestAggregatedQueueServer_ReturningLease_IncrementsRetries(t *testing.T) { assert.Nil(t, addJobsErr) // Does not count towards retries if JobRunAttempted is false - _, err := aggregatedQueueClient.ReturnLease(context.TODO(), &api.ReturnLeaseRequest{ + _, err := aggregatedQueueClient.ReturnLease(armadacontext.TODO(), &api.ReturnLeaseRequest{ ClusterId: clusterId, JobId: jobId, JobRunAttempted: false, @@ -180,7 +180,7 @@ func TestAggregatedQueueServer_ReturningLease_IncrementsRetries(t *testing.T) { assert.Equal(t, 0, numberOfRetries) // Does count towards reties if JobRunAttempted is true - _, err = aggregatedQueueClient.ReturnLease(context.TODO(), &api.ReturnLeaseRequest{ + _, err = aggregatedQueueClient.ReturnLease(armadacontext.TODO(), &api.ReturnLeaseRequest{ ClusterId: clusterId, JobId: jobId, 
JobRunAttempted: true, @@ -452,7 +452,7 @@ type fakeEventStore struct { events []*api.EventMessage } -func (es *fakeEventStore) ReportEvents(_ context.Context, message []*api.EventMessage) error { +func (es *fakeEventStore) ReportEvents(_ *armadacontext.Context, message []*api.EventMessage) error { es.events = append(es.events, message...) return nil } @@ -469,14 +469,14 @@ func (repo *fakeSchedulingInfoRepository) UpdateClusterSchedulingInfo(report *ap type fakeExecutorRepository struct{} -func (f fakeExecutorRepository) GetExecutors(ctx context.Context) ([]*schedulerobjects.Executor, error) { +func (f fakeExecutorRepository) GetExecutors(ctx *armadacontext.Context) ([]*schedulerobjects.Executor, error) { return nil, nil } -func (f fakeExecutorRepository) GetLastUpdateTimes(ctx context.Context) (map[string]time.Time, error) { +func (f fakeExecutorRepository) GetLastUpdateTimes(ctx *armadacontext.Context) (map[string]time.Time, error) { return nil, nil } -func (f fakeExecutorRepository) StoreExecutor(ctx context.Context, executor *schedulerobjects.Executor) error { +func (f fakeExecutorRepository) StoreExecutor(ctx *armadacontext.Context, executor *schedulerobjects.Executor) error { return nil } diff --git a/internal/armada/server/reporting.go b/internal/armada/server/reporting.go index d3a5eae180b..73afc3d3c17 100644 --- a/internal/armada/server/reporting.go +++ b/internal/armada/server/reporting.go @@ -1,13 +1,13 @@ package server import ( - "context" "fmt" "time" log "github.com/sirupsen/logrus" "github.com/armadaproject/armada/internal/armada/repository" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/pkg/api" ) @@ -27,7 +27,7 @@ func reportQueued(repository repository.EventStore, jobs []*api.Job) error { events = append(events, event) } - err := repository.ReportEvents(context.Background(), events) + err := repository.ReportEvents(armadacontext.Background(), events) if err != nil { return fmt.Errorf("[reportQueued] error reporting events: %w", err) } @@ -52,7 +52,7 @@ func reportDuplicateDetected(repository repository.EventStore, results []*reposi events = append(events, event) } - err := repository.ReportEvents(context.Background(), events) + err := repository.ReportEvents(armadacontext.Background(), events) if err != nil { return fmt.Errorf("[reportDuplicateDetected] error reporting events: %w", err) } @@ -77,7 +77,7 @@ func reportSubmitted(repository repository.EventStore, jobs []*api.Job) error { events = append(events, event) } - err := repository.ReportEvents(context.Background(), events) + err := repository.ReportEvents(armadacontext.Background(), events) if err != nil { return fmt.Errorf("[reportSubmitted] error reporting events: %w", err) } @@ -106,7 +106,7 @@ func reportJobsLeased(repository repository.EventStore, jobs []*api.Job, cluster } } - err := repository.ReportEvents(context.Background(), events) + err := repository.ReportEvents(armadacontext.Background(), events) if err != nil { err = fmt.Errorf("[reportJobsLeased] error reporting events: %w", err) log.Error(err) @@ -128,7 +128,7 @@ func reportJobLeaseReturned(repository repository.EventStore, job *api.Job, leas return fmt.Errorf("error wrapping event: %w", err) } - err = repository.ReportEvents(context.Background(), []*api.EventMessage{event}) + err = repository.ReportEvents(armadacontext.Background(), []*api.EventMessage{event}) if err != nil { return fmt.Errorf("error reporting lease returned event: %w", err) } @@ -154,7 +154,7 @@ func 
reportJobsCancelling(repository repository.EventStore, requestorName string events = append(events, event) } - err := repository.ReportEvents(context.Background(), events) + err := repository.ReportEvents(armadacontext.Background(), events) if err != nil { return fmt.Errorf("[reportJobsCancelling] error reporting events: %w", err) } @@ -180,7 +180,7 @@ func reportJobsReprioritizing(repository repository.EventStore, requestorName st events = append(events, event) } - err := repository.ReportEvents(context.Background(), events) + err := repository.ReportEvents(armadacontext.Background(), events) if err != nil { return fmt.Errorf("[reportJobsReprioritizing] error reporting events: %w", err) } @@ -206,7 +206,7 @@ func reportJobsReprioritized(repository repository.EventStore, requestorName str events = append(events, event) } - err := repository.ReportEvents(context.Background(), events) + err := repository.ReportEvents(armadacontext.Background(), events) if err != nil { return fmt.Errorf("[reportJobsReprioritized] error reporting events: %w", err) } @@ -232,7 +232,7 @@ func reportJobsUpdated(repository repository.EventStore, requestorName string, j events = append(events, event) } - err := repository.ReportEvents(context.Background(), events) + err := repository.ReportEvents(armadacontext.Background(), events) if err != nil { return fmt.Errorf("[reportJobsUpdated] error reporting events: %w", err) } @@ -259,7 +259,7 @@ func reportJobsCancelled(repository repository.EventStore, requestorName string, events = append(events, event) } - err := repository.ReportEvents(context.Background(), events) + err := repository.ReportEvents(armadacontext.Background(), events) if err != nil { return fmt.Errorf("[reportJobsCancelled] error reporting events: %w", err) } @@ -293,7 +293,7 @@ func reportFailed(repository repository.EventStore, clusterId string, jobFailure events = append(events, event) } - err := repository.ReportEvents(context.Background(), events) + err := repository.ReportEvents(armadacontext.Background(), events) if err != nil { return fmt.Errorf("[reportFailed] error reporting events: %w", err) } diff --git a/internal/armada/server/submit.go b/internal/armada/server/submit.go index ca444ff3099..c129fbb1da1 100644 --- a/internal/armada/server/submit.go +++ b/internal/armada/server/submit.go @@ -20,6 +20,7 @@ import ( "github.com/armadaproject/armada/internal/armada/permissions" "github.com/armadaproject/armada/internal/armada/repository" servervalidation "github.com/armadaproject/armada/internal/armada/validation" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/armadaerrors" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/internal/common/compress" @@ -62,7 +63,7 @@ func NewSubmitServer( NumTestsPerEvictionRun: 10, } - compressorPool := pool.NewObjectPool(context.Background(), pool.NewPooledObjectFactorySimple( + compressorPool := pool.NewObjectPool(armadacontext.Background(), pool.NewPooledObjectFactorySimple( func(context.Context) (interface{}, error) { return compress.NewZlibCompressor(512) }), &poolConfig) @@ -85,7 +86,8 @@ func (server *SubmitServer) Health(ctx context.Context, _ *types.Empty) (*api.He return &api.HealthCheckResponse{Status: api.HealthCheckResponse_SERVING}, nil } -func (server *SubmitServer) GetQueueInfo(ctx context.Context, req *api.QueueInfoRequest) (*api.QueueInfo, error) { +func (server *SubmitServer) GetQueueInfo(grpcCtx context.Context, req 
*api.QueueInfoRequest) (*api.QueueInfo, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) q, err := server.queueRepository.GetQueue(req.Name) var expected *repository.ErrQueueNotFound if errors.Is(err, expected) { @@ -121,7 +123,7 @@ func (server *SubmitServer) GetQueueInfo(ctx context.Context, req *api.QueueInfo }, nil } -func (server *SubmitServer) GetQueue(ctx context.Context, req *api.QueueGetRequest) (*api.Queue, error) { +func (server *SubmitServer) GetQueue(grpcCtx context.Context, req *api.QueueGetRequest) (*api.Queue, error) { queue, err := server.queueRepository.GetQueue(req.Name) var e *repository.ErrQueueNotFound if errors.As(err, &e) { @@ -132,7 +134,8 @@ func (server *SubmitServer) GetQueue(ctx context.Context, req *api.QueueGetReque return queue.ToAPI(), nil } -func (server *SubmitServer) CreateQueue(ctx context.Context, request *api.Queue) (*types.Empty, error) { +func (server *SubmitServer) CreateQueue(grpcCtx context.Context, request *api.Queue) (*types.Empty, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) err := checkPermission(server.permissions, ctx, permissions.CreateQueue) var ep *ErrUnauthorized if errors.As(err, &ep) { @@ -162,9 +165,9 @@ func (server *SubmitServer) CreateQueue(ctx context.Context, request *api.Queue) return &types.Empty{}, nil } -func (server *SubmitServer) CreateQueues(ctx context.Context, request *api.QueueList) (*api.BatchQueueCreateResponse, error) { +func (server *SubmitServer) CreateQueues(grpcCtx context.Context, request *api.QueueList) (*api.BatchQueueCreateResponse, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) var failedQueues []*api.QueueCreateResponse - // Create a queue for each element of the request body and return the failures. for _, queue := range request.Queues { _, err := server.CreateQueue(ctx, queue) @@ -181,7 +184,8 @@ func (server *SubmitServer) CreateQueues(ctx context.Context, request *api.Queue }, nil } -func (server *SubmitServer) UpdateQueue(ctx context.Context, request *api.Queue) (*types.Empty, error) { +func (server *SubmitServer) UpdateQueue(grpcCtx context.Context, request *api.Queue) (*types.Empty, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) err := checkPermission(server.permissions, ctx, permissions.CreateQueue) var ep *ErrUnauthorized if errors.As(err, &ep) { @@ -206,7 +210,8 @@ func (server *SubmitServer) UpdateQueue(ctx context.Context, request *api.Queue) return &types.Empty{}, nil } -func (server *SubmitServer) UpdateQueues(ctx context.Context, request *api.QueueList) (*api.BatchQueueUpdateResponse, error) { +func (server *SubmitServer) UpdateQueues(grpcCtx context.Context, request *api.QueueList) (*api.BatchQueueUpdateResponse, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) var failedQueues []*api.QueueUpdateResponse // Create a queue for each element of the request body and return the failures. 
@@ -225,7 +230,8 @@ func (server *SubmitServer) UpdateQueues(ctx context.Context, request *api.Queue }, nil } -func (server *SubmitServer) DeleteQueue(ctx context.Context, request *api.QueueDeleteRequest) (*types.Empty, error) { +func (server *SubmitServer) DeleteQueue(grpcCtx context.Context, request *api.QueueDeleteRequest) (*types.Empty, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) err := checkPermission(server.permissions, ctx, permissions.DeleteQueue) var ep *ErrUnauthorized if errors.As(err, &ep) { @@ -250,7 +256,8 @@ func (server *SubmitServer) DeleteQueue(ctx context.Context, request *api.QueueD return &types.Empty{}, nil } -func (server *SubmitServer) SubmitJobs(ctx context.Context, req *api.JobSubmitRequest) (*api.JobSubmitResponse, error) { +func (server *SubmitServer) SubmitJobs(grpcCtx context.Context, req *api.JobSubmitRequest) (*api.JobSubmitResponse, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) principal := authorization.GetPrincipal(ctx) jobs, e := server.createJobs(req, principal.GetName(), principal.GetGroupNames()) @@ -404,7 +411,8 @@ func (server *SubmitServer) countQueuedJobs(q queue.Queue) (int64, error) { // CancelJobs cancels jobs identified by the request. // If the request contains a job ID, only the job with that ID is cancelled. // If the request contains a queue name and a job set ID, all jobs matching those are cancelled. -func (server *SubmitServer) CancelJobs(ctx context.Context, request *api.JobCancelRequest) (*api.CancellationResult, error) { +func (server *SubmitServer) CancelJobs(grpcCtx context.Context, request *api.JobCancelRequest) (*api.CancellationResult, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) if request.JobId != "" { return server.cancelJobsById(ctx, request.JobId, request.Reason) } else if request.JobSetId != "" && request.Queue != "" { @@ -413,7 +421,8 @@ func (server *SubmitServer) CancelJobs(ctx context.Context, request *api.JobCanc return nil, status.Errorf(codes.InvalidArgument, "[CancelJobs] specify either job ID or both queue name and job set ID") } -func (server *SubmitServer) CancelJobSet(ctx context.Context, request *api.JobSetCancelRequest) (*types.Empty, error) { +func (server *SubmitServer) CancelJobSet(grpcCtx context.Context, request *api.JobSetCancelRequest) (*types.Empty, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) err := servervalidation.ValidateJobSetFilter(request.Filter) if err != nil { return nil, err @@ -444,7 +453,7 @@ func createJobSetFilter(filter *api.JobSetFilter) *repository.JobSetFilter { } // cancels a job with a given ID -func (server *SubmitServer) cancelJobsById(ctx context.Context, jobId string, reason string) (*api.CancellationResult, error) { +func (server *SubmitServer) cancelJobsById(ctx *armadacontext.Context, jobId string, reason string) (*api.CancellationResult, error) { jobs, err := server.jobRepository.GetExistingJobsByIds([]string{jobId}) if err != nil { return nil, status.Errorf(codes.Unavailable, "[cancelJobsById] error getting job with ID %s: %s", jobId, err) @@ -466,7 +475,7 @@ func (server *SubmitServer) cancelJobsById(ctx context.Context, jobId string, re // cancels all jobs part of a particular job set and queue func (server *SubmitServer) cancelJobsByQueueAndSet( - ctx context.Context, + ctx *armadacontext.Context, queue string, jobSetId string, filter *repository.JobSetFilter, @@ -509,7 +518,7 @@ func (server *SubmitServer) cancelJobsByQueueAndSet( return &api.CancellationResult{CancelledIds: cancelledIds}, nil } -func (server *SubmitServer) 
cancelJobs(ctx context.Context, jobs []*api.Job, reason string) (*api.CancellationResult, error) { +func (server *SubmitServer) cancelJobs(ctx *armadacontext.Context, jobs []*api.Job, reason string) (*api.CancellationResult, error) { principal := authorization.GetPrincipal(ctx) err := server.checkCancelPerms(ctx, jobs) @@ -551,7 +560,7 @@ func (server *SubmitServer) cancelJobs(ctx context.Context, jobs []*api.Job, rea return &api.CancellationResult{CancelledIds: cancelledIds}, nil } -func (server *SubmitServer) checkCancelPerms(ctx context.Context, jobs []*api.Job) error { +func (server *SubmitServer) checkCancelPerms(ctx *armadacontext.Context, jobs []*api.Job) error { queueNames := make(map[string]struct{}) for _, job := range jobs { queueNames[job.Queue] = struct{}{} @@ -581,7 +590,8 @@ func (server *SubmitServer) checkCancelPerms(ctx context.Context, jobs []*api.Jo // ReprioritizeJobs updates the priority of one of more jobs. // Returns a map from job ID to any error (or nil if the call succeeded). -func (server *SubmitServer) ReprioritizeJobs(ctx context.Context, request *api.JobReprioritizeRequest) (*api.JobReprioritizeResponse, error) { +func (server *SubmitServer) ReprioritizeJobs(grpcCtx context.Context, request *api.JobReprioritizeRequest) (*api.JobReprioritizeResponse, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) var jobs []*api.Job if len(request.JobIds) > 0 { existingJobs, err := server.jobRepository.GetExistingJobsByIds(request.JobIds) @@ -674,7 +684,7 @@ func (server *SubmitServer) reportReprioritizedJobEvents(reprioritizedJobs []*ap return nil } -func (server *SubmitServer) checkReprioritizePerms(ctx context.Context, jobs []*api.Job) error { +func (server *SubmitServer) checkReprioritizePerms(ctx *armadacontext.Context, jobs []*api.Job) error { queueNames := make(map[string]struct{}) for _, job := range jobs { queueNames[job.Queue] = struct{}{} @@ -702,7 +712,7 @@ func (server *SubmitServer) checkReprioritizePerms(ctx context.Context, jobs []* return nil } -func (server *SubmitServer) getQueueOrCreate(ctx context.Context, queueName string) (*queue.Queue, error) { +func (server *SubmitServer) getQueueOrCreate(ctx *armadacontext.Context, queueName string) (*queue.Queue, error) { q, e := server.queueRepository.GetQueue(queueName) if e == nil { return &q, nil @@ -753,16 +763,16 @@ func (server *SubmitServer) createJobs(request *api.JobSubmitRequest, owner stri func (server *SubmitServer) createJobsObjects(request *api.JobSubmitRequest, owner string, ownershipGroups []string, getTime func() time.Time, getUlid func() string, ) ([]*api.Job, error) { - compressor, err := server.compressorPool.BorrowObject(context.Background()) + compressor, err := server.compressorPool.BorrowObject(armadacontext.Background()) if err != nil { return nil, err } - defer func(compressorPool *pool.ObjectPool, ctx context.Context, object interface{}) { + defer func(compressorPool *pool.ObjectPool, ctx *armadacontext.Context, object interface{}) { err := compressorPool.ReturnObject(ctx, object) if err != nil { log.WithError(err).Errorf("Error returning compressor to pool") } - }(server.compressorPool, context.Background(), compressor) + }(server.compressorPool, armadacontext.Background(), compressor) compressedOwnershipGroups, err := compress.CompressStringArray(ownershipGroups, compressor.(compress.Compressor)) if err != nil { return nil, err diff --git a/internal/armada/server/submit_from_log.go b/internal/armada/server/submit_from_log.go index 13acbf9904a..90b5ece3553 100644 --- 
a/internal/armada/server/submit_from_log.go +++ b/internal/armada/server/submit_from_log.go @@ -7,19 +7,17 @@ import ( "time" "github.com/apache/pulsar-client-go/pulsar" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/hashicorp/go-multierror" pool "github.com/jolestar/go-commons-pool" "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/armadaproject/armada/internal/armada/repository" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/armadaerrors" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/eventutil" "github.com/armadaproject/armada/internal/common/logging" - "github.com/armadaproject/armada/internal/common/pulsarutils/pulsarrequestid" - "github.com/armadaproject/armada/internal/common/requestid" "github.com/armadaproject/armada/internal/common/schedulers" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/pkg/api" @@ -38,7 +36,7 @@ type SubmitFromLog struct { } // Run the service that reads from Pulsar and updates Armada until the provided context is cancelled. -func (srv *SubmitFromLog) Run(ctx context.Context) error { +func (srv *SubmitFromLog) Run(ctx *armadacontext.Context) error { // Get the configured logger, or the standard logger if none is provided. log := srv.getLogger() log.Info("service started") @@ -95,7 +93,7 @@ func (srv *SubmitFromLog) Run(ctx context.Context) error { default: // Get a message from Pulsar, which consists of a sequence of events (i.e., state transitions). - ctxWithTimeout, cancel := context.WithTimeout(ctx, 10*time.Second) + ctxWithTimeout, cancel := armadacontext.WithTimeout(ctx, 10*time.Second) msg, err := srv.Consumer.Receive(ctxWithTimeout) cancel() if errors.Is(err, context.DeadlineExceeded) { @@ -121,29 +119,18 @@ func (srv *SubmitFromLog) Run(ctx context.Context) error { lastPublishTime = msg.PublishTime() numReceived++ - // Incoming gRPC requests are annotated with a unique id, - // which is included with the corresponding Pulsar message. - requestId := pulsarrequestid.FromMessageOrMissing(msg) - - // Put the requestId into a message-specific context and logger, - // which are passed on to sub-functions. - messageCtx, ok := requestid.AddToIncomingContext(ctx, requestId) - if !ok { - messageCtx = ctx - } - messageLogger := log.WithFields(logrus.Fields{"messageId": msg.ID(), requestid.MetadataKey: requestId}) - ctxWithLogger := ctxlogrus.ToContext(messageCtx, messageLogger) + ctxWithLogger := armadacontext.WithLogField(ctx, "messageId", msg.ID()) // Unmarshal and validate the message. sequence, err := eventutil.UnmarshalEventSequence(ctxWithLogger, msg.Payload()) if err != nil { srv.ack(ctx, msg) - logging.WithStacktrace(messageLogger, err).Warnf("processing message failed; ignoring") + logging.WithStacktrace(ctxWithLogger.Log, err).Warnf("processing message failed; ignoring") numErrored++ break } - messageLogger.WithField("numEvents", len(sequence.Events)).Info("processing sequence") + ctxWithLogger.Log.WithField("numEvents", len(sequence.Events)).Info("processing sequence") // TODO: Improve retry logic. srv.ProcessSequence(ctxWithLogger, sequence) srv.ack(ctx, msg) @@ -155,9 +142,7 @@ func (srv *SubmitFromLog) Run(ctx context.Context) error { // For efficiency, we may process several events at a time. // To maintain ordering, we only do so for subsequences of consecutive events of equal type. 
// The returned bool indicates if the corresponding Pulsar message should be ack'd or not. -func (srv *SubmitFromLog) ProcessSequence(ctx context.Context, sequence *armadaevents.EventSequence) bool { - log := ctxlogrus.Extract(ctx) - +func (srv *SubmitFromLog) ProcessSequence(ctx *armadacontext.Context, sequence *armadaevents.EventSequence) bool { // Sub-functions should always increment the events index unless they experience a transient error. // However, if a permanent error is mis-categorised as transient, we may get stuck forever. // To avoid that issue, we return immediately if timeout time has passed @@ -170,11 +155,11 @@ func (srv *SubmitFromLog) ProcessSequence(ctx context.Context, sequence *armadae for i < len(sequence.Events) && time.Since(lastProgress) < timeout { j, err := srv.ProcessSubSequence(ctx, i, sequence) if err != nil { - logging.WithStacktrace(log, err).WithFields(logrus.Fields{"lowerIndex": i, "upperIndex": j}).Warnf("processing subsequence failed; ignoring") + logging.WithStacktrace(ctx.Log, err).WithFields(logrus.Fields{"lowerIndex": i, "upperIndex": j}).Warnf("processing subsequence failed; ignoring") } if j == i { - log.WithFields(logrus.Fields{"lowerIndex": i, "upperIndex": j}).Info("made no progress") + ctx.Log.WithFields(logrus.Fields{"lowerIndex": i, "upperIndex": j}).Info("made no progress") // We should only get here if a transient error occurs. // Sleep for a bit before retrying. @@ -200,7 +185,7 @@ func (srv *SubmitFromLog) ProcessSequence(ctx context.Context, sequence *armadae // Events are processed by calling into the embedded srv.SubmitServer. // // Not all events are handled by this processor since the legacy scheduler writes some transitions directly to the db. -func (srv *SubmitFromLog) ProcessSubSequence(ctx context.Context, i int, sequence *armadaevents.EventSequence) (j int, err error) { +func (srv *SubmitFromLog) ProcessSubSequence(ctx *armadacontext.Context, i int, sequence *armadaevents.EventSequence) (j int, err error) { j = i // Initially, the next event to be processed is i. if i < 0 || i >= len(sequence.Events) { err = &armadaerrors.ErrInvalidArgument{ @@ -272,7 +257,7 @@ func (srv *SubmitFromLog) ProcessSubSequence(ctx context.Context, i int, sequenc // collectJobSubmitEvents (and the corresponding functions for other types below) // return a slice of events starting at index i in the sequence with equal type. 
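Editor's sketch: the collectors that follow all share the same scan-while-the-type-matches shape, and the diff also converts a generic collectEvents[T] with that shape. One way the generic form could be invoked is sketched below, assuming it sits alongside the collectors in this package; the JobRunRunning case and the wrapper function are illustrative assumptions, not taken from this diff.

// Hypothetical wrapper around the generic collector defined below; the concrete
// event type is an illustrative choice, not a call site from this diff.
func collectRunningEvents(ctx *armadacontext.Context, i int, sequence *armadaevents.EventSequence) []*armadaevents.EventSequence_Event {
	// Returns the run of consecutive JobRunRunning events starting at index i.
	return collectEvents[*armadaevents.EventSequence_Event_JobRunRunning](ctx, i, sequence)
}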
-func collectJobSubmitEvents(ctx context.Context, i int, sequence *armadaevents.EventSequence) []*armadaevents.SubmitJob { +func collectJobSubmitEvents(ctx *armadacontext.Context, i int, sequence *armadaevents.EventSequence) []*armadaevents.SubmitJob { result := make([]*armadaevents.SubmitJob, 0) for j := i; j < len(sequence.Events); j++ { if e, ok := sequence.Events[j].Event.(*armadaevents.EventSequence_Event_SubmitJob); ok { @@ -284,7 +269,7 @@ func collectJobSubmitEvents(ctx context.Context, i int, sequence *armadaevents.E return result } -func collectCancelJobEvents(ctx context.Context, i int, sequence *armadaevents.EventSequence) []*armadaevents.CancelJob { +func collectCancelJobEvents(ctx *armadacontext.Context, i int, sequence *armadaevents.EventSequence) []*armadaevents.CancelJob { result := make([]*armadaevents.CancelJob, 0) for j := i; j < len(sequence.Events); j++ { if e, ok := sequence.Events[j].Event.(*armadaevents.EventSequence_Event_CancelJob); ok { @@ -296,7 +281,7 @@ func collectCancelJobEvents(ctx context.Context, i int, sequence *armadaevents.E return result } -func collectCancelJobSetEvents(ctx context.Context, i int, sequence *armadaevents.EventSequence) []*armadaevents.CancelJobSet { +func collectCancelJobSetEvents(ctx *armadacontext.Context, i int, sequence *armadaevents.EventSequence) []*armadaevents.CancelJobSet { result := make([]*armadaevents.CancelJobSet, 0) for j := i; j < len(sequence.Events); j++ { if e, ok := sequence.Events[j].Event.(*armadaevents.EventSequence_Event_CancelJobSet); ok { @@ -308,7 +293,7 @@ func collectCancelJobSetEvents(ctx context.Context, i int, sequence *armadaevent return result } -func collectReprioritiseJobEvents(ctx context.Context, i int, sequence *armadaevents.EventSequence) []*armadaevents.ReprioritiseJob { +func collectReprioritiseJobEvents(ctx *armadacontext.Context, i int, sequence *armadaevents.EventSequence) []*armadaevents.ReprioritiseJob { result := make([]*armadaevents.ReprioritiseJob, 0) for j := i; j < len(sequence.Events); j++ { if e, ok := sequence.Events[j].Event.(*armadaevents.EventSequence_Event_ReprioritiseJob); ok { @@ -320,7 +305,7 @@ func collectReprioritiseJobEvents(ctx context.Context, i int, sequence *armadaev return result } -func collectReprioritiseJobSetEvents(ctx context.Context, i int, sequence *armadaevents.EventSequence) []*armadaevents.ReprioritiseJobSet { +func collectReprioritiseJobSetEvents(ctx *armadacontext.Context, i int, sequence *armadaevents.EventSequence) []*armadaevents.ReprioritiseJobSet { result := make([]*armadaevents.ReprioritiseJobSet, 0) for j := i; j < len(sequence.Events); j++ { if e, ok := sequence.Events[j].Event.(*armadaevents.EventSequence_Event_ReprioritiseJobSet); ok { @@ -332,7 +317,7 @@ func collectReprioritiseJobSetEvents(ctx context.Context, i int, sequence *armad return result } -func collectEvents[T any](ctx context.Context, i int, sequence *armadaevents.EventSequence) []*armadaevents.EventSequence_Event { +func collectEvents[T any](ctx *armadacontext.Context, i int, sequence *armadaevents.EventSequence) []*armadaevents.EventSequence_Event { events := make([]*armadaevents.EventSequence_Event, 0) for j := i; j < len(sequence.Events); j++ { if _, ok := sequence.Events[j].Event.(T); ok { @@ -359,7 +344,7 @@ func (srv *SubmitFromLog) getLogger() *logrus.Entry { // Specifically, events are not processed if writing to the database results in a network-related error. // For any other error, the jobs are marked as failed and the events are considered to have been processed. 
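Editor's sketch: the (bool, error) convention described above, where a network error leaves events unprocessed so the message is retried and any other failure is treated as handled, looks roughly like the snippet below. storeWithRetrySemantics and the write callback are illustrative stand-ins, not functions in this diff; armadaerrors.IsNetworkError is the helper the diff itself uses.

// Illustrative only: how a handler could map repository errors onto the
// "processed or retry" bool documented above.
func storeWithRetrySemantics(ctx *armadacontext.Context, write func() error) (bool, error) {
	if err := write(); err != nil {
		if armadaerrors.IsNetworkError(err) {
			// Transient: report not-processed so the Pulsar message is redelivered.
			return false, err
		}
		// Permanent: log it and treat the events as processed.
		ctx.Log.WithError(err).Warnf("permanent failure; events treated as processed")
		return true, nil
	}
	return true, nil
}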
func (srv *SubmitFromLog) SubmitJobs( - ctx context.Context, + ctx *armadacontext.Context, userId string, groups []string, queueName string, @@ -376,16 +361,16 @@ func (srv *SubmitFromLog) SubmitJobs( } log := srv.getLogger() - compressor, err := srv.SubmitServer.compressorPool.BorrowObject(context.Background()) + compressor, err := srv.SubmitServer.compressorPool.BorrowObject(armadacontext.Background()) if err != nil { return false, err } - defer func(compressorPool *pool.ObjectPool, ctx context.Context, object interface{}) { + defer func(compressorPool *pool.ObjectPool, ctx *armadacontext.Context, object interface{}) { err := compressorPool.ReturnObject(ctx, object) if err != nil { log.WithError(err).Errorf("Error returning compressor to pool") } - }(srv.SubmitServer.compressorPool, context.Background(), compressor) + }(srv.SubmitServer.compressorPool, armadacontext.Background(), compressor) compressedOwnershipGroups, err := compress.CompressStringArray(groups, compressor.(compress.Compressor)) if err != nil { @@ -455,7 +440,7 @@ type CancelJobPayload struct { } // CancelJobs cancels all jobs specified by the provided events in a single operation. -func (srv *SubmitFromLog) CancelJobs(ctx context.Context, userId string, es []*armadaevents.CancelJob) (bool, error) { +func (srv *SubmitFromLog) CancelJobs(ctx *armadacontext.Context, userId string, es []*armadaevents.CancelJob) (bool, error) { cancelJobPayloads := make([]*CancelJobPayload, len(es)) for i, e := range es { id, err := armadaevents.UlidStringFromProtoUuid(e.JobId) @@ -475,7 +460,7 @@ func (srv *SubmitFromLog) CancelJobs(ctx context.Context, userId string, es []*a // Because event sequences are specific to queue and job set, all CancelJobSet events in a sequence are equivalent, // and we only need to call CancelJobSet once. func (srv *SubmitFromLog) CancelJobSets( - ctx context.Context, + ctx *armadacontext.Context, userId string, queueName string, jobSetName string, @@ -489,7 +474,7 @@ func (srv *SubmitFromLog) CancelJobSets( return srv.CancelJobSet(ctx, userId, queueName, jobSetName, reason) } -func (srv *SubmitFromLog) CancelJobSet(ctx context.Context, userId string, queueName string, jobSetName string, reason string) (bool, error) { +func (srv *SubmitFromLog) CancelJobSet(ctx *armadacontext.Context, userId string, queueName string, jobSetName string, reason string) (bool, error) { jobIds, err := srv.SubmitServer.jobRepository.GetActiveJobIds(queueName, jobSetName) if armadaerrors.IsNetworkError(err) { return false, err @@ -505,7 +490,7 @@ func (srv *SubmitFromLog) CancelJobSet(ctx context.Context, userId string, queue return srv.BatchedCancelJobsById(ctx, userId, cancelJobPayloads) } -func (srv *SubmitFromLog) BatchedCancelJobsById(ctx context.Context, userId string, cancelJobPayloads []*CancelJobPayload) (bool, error) { +func (srv *SubmitFromLog) BatchedCancelJobsById(ctx *armadacontext.Context, userId string, cancelJobPayloads []*CancelJobPayload) (bool, error) { // Split IDs into batches and process one batch at a time. // To reduce the number of jobs stored in memory. // @@ -538,7 +523,7 @@ type CancelledJobPayload struct { } // CancelJobsById cancels all jobs with the specified ids. 
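Editor's note: SubmitJobs above borrows a compressor from the shared pool with an armadacontext and hands it back in a deferred call, mirroring createJobsObjects earlier in the diff. A condensed sketch of that borrow-and-return shape, using a closure instead of the parameterised defer; srv and log are the names already in scope in SubmitJobs.

compressor, err := srv.SubmitServer.compressorPool.BorrowObject(armadacontext.Background())
if err != nil {
	return false, err
}
defer func() {
	// Always return the object; a failed return is logged rather than propagated.
	if err := srv.SubmitServer.compressorPool.ReturnObject(armadacontext.Background(), compressor); err != nil {
		log.WithError(err).Errorf("Error returning compressor to pool")
	}
}()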
-func (srv *SubmitFromLog) CancelJobsById(ctx context.Context, userId string, cancelJobPayloads []*CancelJobPayload) ([]string, error) { +func (srv *SubmitFromLog) CancelJobsById(ctx *armadacontext.Context, userId string, cancelJobPayloads []*CancelJobPayload) ([]string, error) { jobIdReasonMap := make(map[string]string) jobIds := util.Map(cancelJobPayloads, func(payload *CancelJobPayload) string { jobIdReasonMap[payload.JobId] = payload.Reason @@ -588,7 +573,7 @@ func (srv *SubmitFromLog) CancelJobsById(ctx context.Context, userId string, can } // ReprioritizeJobs updates the priority of one of more jobs. -func (srv *SubmitFromLog) ReprioritizeJobs(ctx context.Context, userId string, es []*armadaevents.ReprioritiseJob) (bool, error) { +func (srv *SubmitFromLog) ReprioritizeJobs(ctx *armadacontext.Context, userId string, es []*armadaevents.ReprioritiseJob) (bool, error) { if len(es) == 0 { return true, nil } @@ -635,7 +620,7 @@ func (srv *SubmitFromLog) ReprioritizeJobs(ctx context.Context, userId string, e return true, nil } -func (srv *SubmitFromLog) DeleteFailedJobs(ctx context.Context, es []*armadaevents.EventSequence_Event) (bool, error) { +func (srv *SubmitFromLog) DeleteFailedJobs(ctx *armadacontext.Context, es []*armadaevents.EventSequence_Event) (bool, error) { jobIdsToDelete := make([]string, 0, len(es)) for _, event := range es { jobErrors := event.GetJobErrors() @@ -664,7 +649,7 @@ func (srv *SubmitFromLog) DeleteFailedJobs(ctx context.Context, es []*armadaeven } // UpdateJobStartTimes records the start time (in Redis) of one of more jobs. -func (srv *SubmitFromLog) UpdateJobStartTimes(ctx context.Context, es []*armadaevents.EventSequence_Event) (bool, error) { +func (srv *SubmitFromLog) UpdateJobStartTimes(ctx *armadacontext.Context, es []*armadaevents.EventSequence_Event) (bool, error) { jobStartsInfos := make([]*repository.JobStartInfo, 0, len(es)) for _, event := range es { jobRun := event.GetJobRunRunning() @@ -713,7 +698,7 @@ func (srv *SubmitFromLog) UpdateJobStartTimes(ctx context.Context, es []*armadae // Since repeating this operation is safe (setting the priority is idempotent), // the bool indicating if events were processed is set to false if any job set failed. 
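Editor's sketch: the aggregation described in the comment above, where one failed job set marks the whole batch as unprocessed so the message is retried (safe because setting a priority is idempotent), could be written as below. applyOne is a stand-in for the per-job-set call; the wrapper itself is not part of this diff.

// Illustrative aggregation: any failure flips the "processed" flag to false so
// the whole message is redelivered and reapplied.
func reprioritizeAll(
	ctx *armadacontext.Context,
	es []*armadaevents.ReprioritiseJobSet,
	applyOne func(*armadacontext.Context, *armadaevents.ReprioritiseJobSet) (bool, error),
) (bool, error) {
	okAll := true
	for _, e := range es {
		ok, err := applyOne(ctx, e)
		if err != nil {
			logging.WithStacktrace(ctx.Log, err).Warnf("reprioritising job set failed")
		}
		okAll = okAll && ok
	}
	return okAll, nil
}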
func (srv *SubmitFromLog) ReprioritizeJobSets( - ctx context.Context, + ctx *armadacontext.Context, userId string, queueName string, jobSetName string, @@ -730,7 +715,7 @@ func (srv *SubmitFromLog) ReprioritizeJobSets( } func (srv *SubmitFromLog) ReprioritizeJobSet( - ctx context.Context, + ctx *armadacontext.Context, userId string, queueName string, jobSetName string, @@ -767,7 +752,7 @@ func (srv *SubmitFromLog) ReprioritizeJobSet( return true, nil } -func (srv *SubmitFromLog) ack(ctx context.Context, msg pulsar.Message) { +func (srv *SubmitFromLog) ack(ctx *armadacontext.Context, msg pulsar.Message) { util.RetryUntilSuccess( ctx, func() error { diff --git a/internal/armada/server/submit_from_log_test.go b/internal/armada/server/submit_from_log_test.go index c3479888d06..45368bfe7e2 100644 --- a/internal/armada/server/submit_from_log_test.go +++ b/internal/armada/server/submit_from_log_test.go @@ -1,13 +1,13 @@ package server import ( - ctx "context" "fmt" "testing" "github.com/stretchr/testify/assert" "github.com/armadaproject/armada/internal/armada/repository" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/ingest/testfixtures" "github.com/armadaproject/armada/pkg/armadaevents" ) @@ -42,7 +42,7 @@ func TestUpdateJobStartTimes(t *testing.T) { }, } - ok, err := s.UpdateJobStartTimes(ctx.Background(), events) + ok, err := s.UpdateJobStartTimes(armadacontext.Background(), events) assert.NoError(t, err) assert.True(t, ok) @@ -59,7 +59,7 @@ func TestUpdateJobStartTimes_NonExistentJob(t *testing.T) { jobRepository: jobRepo, }, } - ok, err := s.UpdateJobStartTimes(ctx.Background(), events) + ok, err := s.UpdateJobStartTimes(armadacontext.Background(), events) assert.Nil(t, err) assert.True(t, ok) @@ -75,7 +75,7 @@ func TestUpdateJobStartTimes_RedisError(t *testing.T) { jobRepository: jobRepo, }, } - ok, err := s.UpdateJobStartTimes(ctx.Background(), events) + ok, err := s.UpdateJobStartTimes(armadacontext.Background(), events) assert.Error(t, err) assert.False(t, ok) diff --git a/internal/armada/server/submit_to_log.go b/internal/armada/server/submit_to_log.go index cf4b12ceca2..aaaee8a35e2 100644 --- a/internal/armada/server/submit_to_log.go +++ b/internal/armada/server/submit_to_log.go @@ -20,6 +20,7 @@ import ( "github.com/armadaproject/armada/internal/armada/permissions" "github.com/armadaproject/armada/internal/armada/repository" "github.com/armadaproject/armada/internal/armada/validation" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/armadaerrors" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/internal/common/auth/permission" @@ -68,7 +69,8 @@ type PulsarSubmitServer struct { IgnoreJobSubmitChecks bool } -func (srv *PulsarSubmitServer) SubmitJobs(ctx context.Context, req *api.JobSubmitRequest) (*api.JobSubmitResponse, error) { +func (srv *PulsarSubmitServer) SubmitJobs(grpcCtx context.Context, req *api.JobSubmitRequest) (*api.JobSubmitResponse, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) userId, groups, err := srv.Authorize(ctx, req.Queue, permissions.SubmitAnyJobs, queue.PermissionVerbSubmit) if err != nil { return nil, err @@ -240,7 +242,9 @@ func (srv *PulsarSubmitServer) SubmitJobs(ctx context.Context, req *api.JobSubmi return &api.JobSubmitResponse{JobResponseItems: responses}, nil } -func (srv *PulsarSubmitServer) CancelJobs(ctx context.Context, req *api.JobCancelRequest) 
(*api.CancellationResult, error) { +func (srv *PulsarSubmitServer) CancelJobs(grpcCtx context.Context, req *api.JobCancelRequest) (*api.CancellationResult, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) + // separate code path for multiple jobs if len(req.JobIds) > 0 { return srv.cancelJobsByIdsQueueJobset(ctx, req.JobIds, req.Queue, req.JobSetId, req.Reason) @@ -328,7 +332,8 @@ func (srv *PulsarSubmitServer) CancelJobs(ctx context.Context, req *api.JobCance } // Assumes all Job IDs are in the queue and job set provided -func (srv *PulsarSubmitServer) cancelJobsByIdsQueueJobset(ctx context.Context, jobIds []string, q, jobSet string, reason string) (*api.CancellationResult, error) { +func (srv *PulsarSubmitServer) cancelJobsByIdsQueueJobset(grpcCtx context.Context, jobIds []string, q, jobSet string, reason string) (*api.CancellationResult, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) if q == "" { return nil, &armadaerrors.ErrInvalidArgument{ Name: "Queue", @@ -390,7 +395,8 @@ func eventSequenceForJobIds(jobIds []string, q, jobSet, userId string, groups [] return sequence, validIds } -func (srv *PulsarSubmitServer) CancelJobSet(ctx context.Context, req *api.JobSetCancelRequest) (*types.Empty, error) { +func (srv *PulsarSubmitServer) CancelJobSet(grpcCtx context.Context, req *api.JobSetCancelRequest) (*types.Empty, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) if req.Queue == "" { return nil, &armadaerrors.ErrInvalidArgument{ Name: "Queue", @@ -492,7 +498,9 @@ func (srv *PulsarSubmitServer) CancelJobSet(ctx context.Context, req *api.JobSet return &types.Empty{}, err } -func (srv *PulsarSubmitServer) ReprioritizeJobs(ctx context.Context, req *api.JobReprioritizeRequest) (*api.JobReprioritizeResponse, error) { +func (srv *PulsarSubmitServer) ReprioritizeJobs(grpcCtx context.Context, req *api.JobReprioritizeRequest) (*api.JobReprioritizeResponse, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) + // If either queue or jobSetId is missing, we get the job set and queue associated // with the first job id in the request. // @@ -612,7 +620,7 @@ func (srv *PulsarSubmitServer) ReprioritizeJobs(ctx context.Context, req *api.Jo // Checks that the user has either anyPerm (e.g., permissions.SubmitAnyJobs) or perm (e.g., PermissionVerbSubmit) for this queue. // Returns the userId and groups extracted from the context. func (srv *PulsarSubmitServer) Authorize( - ctx context.Context, + ctx *armadacontext.Context, queueName string, anyPerm permission.Permission, perm queue.PermissionVerb, @@ -694,7 +702,7 @@ func (srv *PulsarSubmitServer) GetQueueInfo(ctx context.Context, req *api.QueueI } // PublishToPulsar sends pulsar messages async -func (srv *PulsarSubmitServer) publishToPulsar(ctx context.Context, sequences []*armadaevents.EventSequence, scheduler schedulers.Scheduler) error { +func (srv *PulsarSubmitServer) publishToPulsar(ctx *armadacontext.Context, sequences []*armadaevents.EventSequence, scheduler schedulers.Scheduler) error { // Reduce the number of sequences to send to the minimum possible, // and then break up any sequences larger than srv.MaxAllowedMessageSize. sequences = eventutil.CompactEventSequences(sequences) @@ -714,7 +722,7 @@ func jobKey(j *api.Job) string { // getOriginalJobIds returns the mapping between jobId and originalJobId. If the job (or more specifically the clientId // on the job) has not been seen before then jobId -> jobId. 
If the job has been seen before then jobId -> originalJobId // Note that if srv.KVStore is nil then this function simply returns jobId -> jobId -func (srv *PulsarSubmitServer) getOriginalJobIds(ctx context.Context, apiJobs []*api.Job) (map[string]string, error) { +func (srv *PulsarSubmitServer) getOriginalJobIds(ctx *armadacontext.Context, apiJobs []*api.Job) (map[string]string, error) { // Default is the current id ret := make(map[string]string, len(apiJobs)) for _, apiJob := range apiJobs { @@ -753,7 +761,7 @@ func (srv *PulsarSubmitServer) getOriginalJobIds(ctx context.Context, apiJobs [] return ret, nil } -func (srv *PulsarSubmitServer) storeOriginalJobIds(ctx context.Context, apiJobs []*api.Job) error { +func (srv *PulsarSubmitServer) storeOriginalJobIds(ctx *armadacontext.Context, apiJobs []*api.Job) error { if srv.KVStore == nil { return nil } diff --git a/internal/armada/server/usage.go b/internal/armada/server/usage.go index 92fe54abd45..9c6e1e7800e 100644 --- a/internal/armada/server/usage.go +++ b/internal/armada/server/usage.go @@ -12,6 +12,7 @@ import ( "github.com/armadaproject/armada/internal/armada/permissions" "github.com/armadaproject/armada/internal/armada/repository" "github.com/armadaproject/armada/internal/armada/scheduling" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/pkg/api" "github.com/armadaproject/armada/pkg/client/queue" @@ -41,7 +42,8 @@ func NewUsageServer( } } -func (s *UsageServer) ReportUsage(ctx context.Context, report *api.ClusterUsageReport) (*types.Empty, error) { +func (s *UsageServer) ReportUsage(grpcCtx context.Context, report *api.ClusterUsageReport) (*types.Empty, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) if err := checkPermission(s.permissions, ctx, permissions.ExecuteJobs); err != nil { return nil, status.Errorf(codes.PermissionDenied, "[ReportUsage] error: %s", err) } diff --git a/internal/armada/server/usage_test.go b/internal/armada/server/usage_test.go index 6464b154880..8f1fa88b30b 100644 --- a/internal/armada/server/usage_test.go +++ b/internal/armada/server/usage_test.go @@ -1,7 +1,6 @@ package server import ( - "context" "testing" "time" @@ -12,6 +11,7 @@ import ( "github.com/armadaproject/armada/internal/armada/configuration" "github.com/armadaproject/armada/internal/armada/repository" + "github.com/armadaproject/armada/internal/common/armadacontext" armadaresource "github.com/armadaproject/armada/internal/common/resource" "github.com/armadaproject/armada/pkg/api" "github.com/armadaproject/armada/pkg/client/queue" @@ -26,14 +26,14 @@ func TestUsageServer_ReportUsage(t *testing.T) { err := s.queueRepository.CreateQueue(queue.Queue{Name: "q1", PriorityFactor: 1}) assert.Nil(t, err) - _, err = s.ReportUsage(context.Background(), oneQueueReport(now, cpu, memory)) + _, err = s.ReportUsage(armadacontext.Background(), oneQueueReport(now, cpu, memory)) assert.Nil(t, err) priority, err := s.usageRepository.GetClusterPriority("clusterA") assert.Nil(t, err) assert.Equal(t, 10.0, priority["q1"], "Priority should be updated for the new cluster.") - _, err = s.ReportUsage(context.Background(), oneQueueReport(now.Add(time.Minute), cpu, memory)) + _, err = s.ReportUsage(armadacontext.Background(), oneQueueReport(now.Add(time.Minute), cpu, memory)) assert.Nil(t, err) priority, err = s.usageRepository.GetClusterPriority("clusterA") @@ -51,14 +51,14 @@ func TestUsageServer_ReportUsageWithDefinedScarcity(t *testing.T) { err 
:= s.queueRepository.CreateQueue(queue.Queue{Name: "q1", PriorityFactor: 1}) assert.Nil(t, err) - _, err = s.ReportUsage(context.Background(), oneQueueReport(now, cpu, memory)) + _, err = s.ReportUsage(armadacontext.Background(), oneQueueReport(now, cpu, memory)) assert.Nil(t, err) priority, err := s.usageRepository.GetClusterPriority("clusterA") assert.Nil(t, err) assert.Equal(t, 5.0, priority["q1"], "Priority should be updated for the new cluster.") - _, err = s.ReportUsage(context.Background(), oneQueueReport(now.Add(time.Minute), cpu, memory)) + _, err = s.ReportUsage(armadacontext.Background(), oneQueueReport(now.Add(time.Minute), cpu, memory)) assert.Nil(t, err) priority, err = s.usageRepository.GetClusterPriority("clusterA") diff --git a/internal/armadactl/analyze.go b/internal/armadactl/analyze.go index de9d29fb5dc..650c0861684 100644 --- a/internal/armadactl/analyze.go +++ b/internal/armadactl/analyze.go @@ -1,11 +1,11 @@ package armadactl import ( - "context" "encoding/json" "fmt" "reflect" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/pkg/api" "github.com/armadaproject/armada/pkg/client" "github.com/armadaproject/armada/pkg/client/domain" @@ -17,7 +17,7 @@ func (a *App) Analyze(queue string, jobSetId string) error { events := map[string][]*api.Event{} var jobState *domain.WatchContext - client.WatchJobSet(ec, queue, jobSetId, false, true, false, false, context.Background(), func(state *domain.WatchContext, e api.Event) bool { + client.WatchJobSet(ec, queue, jobSetId, false, true, false, false, armadacontext.Background(), func(state *domain.WatchContext, e api.Event) bool { events[e.GetJobId()] = append(events[e.GetJobId()], &e) jobState = state return false diff --git a/internal/armadactl/kube.go b/internal/armadactl/kube.go index ef466f7e6b8..d9b63a0399a 100644 --- a/internal/armadactl/kube.go +++ b/internal/armadactl/kube.go @@ -1,10 +1,10 @@ package armadactl import ( - "context" "fmt" "strings" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/pkg/api" "github.com/armadaproject/armada/pkg/client" ) @@ -14,7 +14,7 @@ import ( func (a *App) Kube(jobId string, queueName string, jobSetId string, podNumber int, args []string) error { verb := strings.Join(args, " ") return client.WithEventClient(a.Params.ApiConnectionDetails, func(c api.EventClient) error { - state := client.GetJobSetState(c, queueName, jobSetId, context.Background(), true, false, false) + state := client.GetJobSetState(c, queueName, jobSetId, armadacontext.Background(), true, false, false) jobInfo := state.GetJobInfo(jobId) if jobInfo == nil { diff --git a/internal/armadactl/resources.go b/internal/armadactl/resources.go index 4cf4faa653c..8a7f018bc0d 100644 --- a/internal/armadactl/resources.go +++ b/internal/armadactl/resources.go @@ -1,9 +1,9 @@ package armadactl import ( - "context" "fmt" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/pkg/api" "github.com/armadaproject/armada/pkg/client" ) @@ -11,7 +11,7 @@ import ( // Resources prints the resources used by the jobs in job set with ID jobSetId in the given queue. 
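Editor's note: call sites such as WatchJobSet and GetJobSetState above, and the Kubernetes client calls in the tests further down, can take armadacontext.Background() unchanged because *armadacontext.Context embeds context.Context and therefore satisfies it. A small sketch of that property; needsPlainContext and example are illustrative names, not code from this diff.

// needsPlainContext stands in for any third-party API (pgx, client-go, the
// Armada client) that accepts a plain context.Context.
func needsPlainContext(ctx context.Context) error { return ctx.Err() }

func example() error {
	ctx := armadacontext.WithLogField(armadacontext.Background(), "component", "armadactl")
	// The embedded context.Context is promoted, so no conversion is needed here.
	return needsPlainContext(ctx)
}

The reverse direction is what FromGrpcCtx in the new package is for: wrapping an incoming plain context and pulling its ctxlogrus logger into ctx.Log.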
func (a *App) Resources(queueName string, jobSetId string) error { return client.WithEventClient(a.Params.ApiConnectionDetails, func(c api.EventClient) error { - state := client.GetJobSetState(c, queueName, jobSetId, context.Background(), true, false, false) + state := client.GetJobSetState(c, queueName, jobSetId, armadacontext.Background(), true, false, false) for _, job := range state.GetCurrentState() { fmt.Fprintf(a.Out, "Job ID: %v, maximum used resources: %v\n", job.Job.Id, job.MaxUsedResources) diff --git a/internal/armadactl/watch.go b/internal/armadactl/watch.go index fd0d842d5cf..872a01388c8 100644 --- a/internal/armadactl/watch.go +++ b/internal/armadactl/watch.go @@ -1,12 +1,12 @@ package armadactl import ( - "context" "encoding/json" "fmt" "reflect" "time" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/pkg/api" "github.com/armadaproject/armada/pkg/client" "github.com/armadaproject/armada/pkg/client/domain" @@ -16,7 +16,7 @@ import ( func (a *App) Watch(queue string, jobSetId string, raw bool, exitOnInactive bool, forceNewEvents bool, forceLegacyEvents bool) error { fmt.Fprintf(a.Out, "Watching job set %s\n", jobSetId) return client.WithEventClient(a.Params.ApiConnectionDetails, func(c api.EventClient) error { - client.WatchJobSet(c, queue, jobSetId, true, true, forceNewEvents, forceLegacyEvents, context.Background(), func(state *domain.WatchContext, event api.Event) bool { + client.WatchJobSet(c, queue, jobSetId, true, true, forceNewEvents, forceLegacyEvents, armadacontext.Background(), func(state *domain.WatchContext, event api.Event) bool { if raw { data, err := json.Marshal(event) if err != nil { diff --git a/internal/binoculars/server/binoculars.go b/internal/binoculars/server/binoculars.go index 4497573a04d..0a08237058f 100644 --- a/internal/binoculars/server/binoculars.go +++ b/internal/binoculars/server/binoculars.go @@ -8,6 +8,7 @@ import ( "github.com/armadaproject/armada/internal/binoculars/service" "github.com/armadaproject/armada/internal/common" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/pkg/api/binoculars" ) @@ -27,7 +28,7 @@ func NewBinocularsServer(logService service.LogService, cordonService service.Co func (b *BinocularsServer) Logs(ctx context.Context, request *binoculars.LogRequest) (*binoculars.LogResponse, error) { principal := authorization.GetPrincipal(ctx) - logLines, err := b.logService.GetLogs(ctx, &service.LogParams{ + logLines, err := b.logService.GetLogs(armadacontext.FromGrpcCtx(ctx), &service.LogParams{ Principal: principal, Namespace: request.PodNamespace, PodName: common.PodNamePrefix + request.JobId + "-" + strconv.Itoa(int(request.PodNumber)), @@ -42,7 +43,7 @@ func (b *BinocularsServer) Logs(ctx context.Context, request *binoculars.LogRequ } func (b *BinocularsServer) Cordon(ctx context.Context, request *binoculars.CordonRequest) (*types.Empty, error) { - err := b.cordonService.CordonNode(ctx, request) + err := b.cordonService.CordonNode(armadacontext.FromGrpcCtx(ctx), request) if err != nil { return nil, err } diff --git a/internal/binoculars/service/cordon.go b/internal/binoculars/service/cordon.go index 8d850bca8ec..584da9bf4ca 100644 --- a/internal/binoculars/service/cordon.go +++ b/internal/binoculars/service/cordon.go @@ -1,7 +1,6 @@ package service import ( - "context" "encoding/json" "fmt" "strings" @@ -14,6 +13,7 @@ import ( 
"github.com/armadaproject/armada/internal/armada/permissions" "github.com/armadaproject/armada/internal/binoculars/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/internal/common/auth/permission" "github.com/armadaproject/armada/internal/common/cluster" @@ -23,7 +23,7 @@ import ( const userTemplate = "" type CordonService interface { - CordonNode(ctx context.Context, request *binoculars.CordonRequest) error + CordonNode(ctx *armadacontext.Context, request *binoculars.CordonRequest) error } type KubernetesCordonService struct { @@ -44,7 +44,7 @@ func NewKubernetesCordonService( } } -func (c *KubernetesCordonService) CordonNode(ctx context.Context, request *binoculars.CordonRequest) error { +func (c *KubernetesCordonService) CordonNode(ctx *armadacontext.Context, request *binoculars.CordonRequest) error { err := checkPermission(c.permissionChecker, ctx, permissions.CordonNodes) if err != nil { return status.Errorf(codes.PermissionDenied, err.Error()) @@ -91,7 +91,7 @@ func GetPatchBytes(patchData *nodePatch) ([]byte, error) { return json.Marshal(patchData) } -func checkPermission(p authorization.PermissionChecker, ctx context.Context, permission permission.Permission) error { +func checkPermission(p authorization.PermissionChecker, ctx *armadacontext.Context, permission permission.Permission) error { if !p.UserHasPermission(ctx, permission) { return fmt.Errorf("user %s does not have permission %s", authorization.GetPrincipal(ctx).GetName(), permission) } diff --git a/internal/binoculars/service/cordon_test.go b/internal/binoculars/service/cordon_test.go index 5a1cce961b9..eadac72fd8e 100644 --- a/internal/binoculars/service/cordon_test.go +++ b/internal/binoculars/service/cordon_test.go @@ -6,6 +6,7 @@ import ( "fmt" "testing" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" @@ -19,6 +20,7 @@ import ( clientTesting "k8s.io/client-go/testing" "github.com/armadaproject/armada/internal/binoculars/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/internal/common/auth/permission" "github.com/armadaproject/armada/pkg/api/binoculars" @@ -79,7 +81,7 @@ func TestCordonNode(t *testing.T) { cordonService, client := setupTest(t, cordonConfig, FakePermissionChecker{ReturnValue: true}) ctx := authorization.WithPrincipal(context.Background(), principal) - err := cordonService.CordonNode(ctx, &binoculars.CordonRequest{ + err := cordonService.CordonNode(armadacontext.New(ctx, logrus.NewEntry(logrus.New())), &binoculars.CordonRequest{ NodeName: defaultNode.Name, }) assert.Nil(t, err) @@ -96,7 +98,7 @@ func TestCordonNode(t *testing.T) { assert.Equal(t, patch, tc.expectedPatch) // Assert resulting node is in expected state - node, err := client.CoreV1().Nodes().Get(context.Background(), defaultNode.Name, metav1.GetOptions{}) + node, err := client.CoreV1().Nodes().Get(armadacontext.Background(), defaultNode.Name, metav1.GetOptions{}) assert.Nil(t, err) assert.Equal(t, node.Spec.Unschedulable, true) assert.Equal(t, node.Labels, tc.expectedLabels) @@ -107,7 +109,7 @@ func TestCordonNode(t *testing.T) { func TestCordonNode_InvalidNodeName(t *testing.T) { cordonService, _ := setupTest(t, defaultCordonConfig, FakePermissionChecker{ReturnValue: true}) - err := 
cordonService.CordonNode(context.Background(), &binoculars.CordonRequest{ + err := cordonService.CordonNode(armadacontext.Background(), &binoculars.CordonRequest{ NodeName: "non-existent-node", }) @@ -117,7 +119,7 @@ func TestCordonNode_InvalidNodeName(t *testing.T) { func TestCordonNode_Unauthenticated(t *testing.T) { cordonService, _ := setupTest(t, defaultCordonConfig, FakePermissionChecker{ReturnValue: false}) - err := cordonService.CordonNode(context.Background(), &binoculars.CordonRequest{ + err := cordonService.CordonNode(armadacontext.Background(), &binoculars.CordonRequest{ NodeName: defaultNode.Name, }) @@ -131,7 +133,7 @@ func setupTest(t *testing.T, config configuration.CordonConfiguration, permissio client := fake.NewSimpleClientset() clientProvider := &FakeClientProvider{FakeClient: client} - _, err := client.CoreV1().Nodes().Create(context.Background(), defaultNode, metav1.CreateOptions{}) + _, err := client.CoreV1().Nodes().Create(armadacontext.Background(), defaultNode, metav1.CreateOptions{}) require.NoError(t, err) client.Fake.ClearActions() diff --git a/internal/binoculars/service/logs.go b/internal/binoculars/service/logs.go index 49801758292..ac72215f67e 100644 --- a/internal/binoculars/service/logs.go +++ b/internal/binoculars/service/logs.go @@ -1,7 +1,6 @@ package service import ( - "context" "fmt" "strings" "time" @@ -10,13 +9,14 @@ import ( v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/internal/common/cluster" "github.com/armadaproject/armada/pkg/api/binoculars" ) type LogService interface { - GetLogs(ctx context.Context, params *LogParams) ([]*binoculars.LogLine, error) + GetLogs(ctx *armadacontext.Context, params *LogParams) ([]*binoculars.LogLine, error) } type LogParams struct { @@ -37,7 +37,7 @@ func NewKubernetesLogService(clientProvider cluster.KubernetesClientProvider) *K return &KubernetesLogService{clientProvider: clientProvider} } -func (l *KubernetesLogService) GetLogs(ctx context.Context, params *LogParams) ([]*binoculars.LogLine, error) { +func (l *KubernetesLogService) GetLogs(ctx *armadacontext.Context, params *LogParams) ([]*binoculars.LogLine, error) { client, err := l.clientProvider.ClientForUser(params.Principal.GetName(), params.Principal.GetGroupNames()) if err != nil { return nil, err diff --git a/internal/common/app/app.go b/internal/common/app/app.go index bd35f7a5a8f..25ce1e828b0 100644 --- a/internal/common/app/app.go +++ b/internal/common/app/app.go @@ -1,15 +1,16 @@ package app import ( - "context" "os" "os/signal" "syscall" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) // CreateContextWithShutdown returns a context that will report done when a SIGTERM is received -func CreateContextWithShutdown() context.Context { - ctx, cancel := context.WithCancel(context.Background()) +func CreateContextWithShutdown() *armadacontext.Context { + ctx, cancel := armadacontext.WithCancel(armadacontext.Background()) c := make(chan os.Signal, 1) signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) go func() { diff --git a/internal/common/armadacontext/armada_context.go b/internal/common/armadacontext/armada_context.go new file mode 100644 index 00000000000..a6985ee5df7 --- /dev/null +++ b/internal/common/armadacontext/armada_context.go @@ -0,0 +1,107 @@ +package armadacontext + +import ( + "context" + "time" + + 
"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" + "github.com/sirupsen/logrus" + "golang.org/x/sync/errgroup" +) + +// Context is an extension of Go's context which also includes a logger. This allows us to pass round a contextual logger +// while retaining type-safety +type Context struct { + context.Context + Log *logrus.Entry +} + +// Background creates an empty context with a default logger. It is analogous to context.Background() +func Background() *Context { + return &Context{ + Context: context.Background(), + Log: logrus.NewEntry(logrus.New()), + } +} + +// TODO creates an empty context with a default logger. It is analogous to context.TODO() +func TODO() *Context { + return &Context{ + Context: context.TODO(), + Log: logrus.NewEntry(logrus.New()), + } +} + +// FromGrpcCtx creates a context where the logger is extracted via ctxlogrus's Extract() method. +// Note that this will result in a no-op logger if a logger hasn't already been inserted into the context via ctxlogrus +func FromGrpcCtx(ctx context.Context) *Context { + log := ctxlogrus.Extract(ctx) + return New(ctx, log) +} + +// New returns an armada context that encapsulates both a go context and a logger +func New(ctx context.Context, log *logrus.Entry) *Context { + return &Context{ + Context: ctx, + Log: log, + } +} + +// WithCancel returns a copy of parent with a new Done channel. It is analogous to context.WithCancel() +func WithCancel(parent *Context) (*Context, context.CancelFunc) { + c, cancel := context.WithCancel(parent.Context) + return &Context{ + Context: c, + Log: parent.Log, + }, cancel +} + +// WithDeadline returns a copy of the parent context with the deadline adjusted to be no later than d. +// It is analogous to context.WithDeadline() +func WithDeadline(parent *Context, d time.Time) (*Context, context.CancelFunc) { + c, cancel := context.WithDeadline(parent.Context, d) + return &Context{ + Context: c, + Log: parent.Log, + }, cancel +} + +// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)). It is analogous to context.WithTimeout() +func WithTimeout(parent *Context, timeout time.Duration) (*Context, context.CancelFunc) { + return WithDeadline(parent, time.Now().Add(timeout)) +} + +// WithLogField returns a copy of parent with the supplied key-value added to the logger +func WithLogField(parent *Context, key string, val interface{}) *Context { + return &Context{ + Context: parent.Context, + Log: parent.Log.WithField(key, val), + } +} + +// WithLogFields returns a copy of parent with the supplied key-values added to the logger +func WithLogFields(parent *Context, fields logrus.Fields) *Context { + return &Context{ + Context: parent.Context, + Log: parent.Log.WithFields(fields), + } +} + +// WithValue returns a copy of parent in which the value associated with key is +// val. It is analogous to context.WithValue() +func WithValue(parent *Context, key, val any) *Context { + return &Context{ + Context: context.WithValue(parent, key, val), + Log: parent.Log, + } +} + +// ErrGroup returns a new Error Group and an associated Context derived from ctx. 
+// It is analogous to errgroup.WithContext(ctx) +func ErrGroup(ctx *Context) (*errgroup.Group, *Context) { + group, goctx := errgroup.WithContext(ctx) + return group, &Context{ + Context: goctx, + Log: ctx.Log, + } +} diff --git a/internal/common/armadacontext/armada_context_test.go b/internal/common/armadacontext/armada_context_test.go new file mode 100644 index 00000000000..a98d7b611df --- /dev/null +++ b/internal/common/armadacontext/armada_context_test.go @@ -0,0 +1,89 @@ +package armadacontext + +import ( + "context" + "testing" + "time" + + "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" +) + +var defaultLogger = logrus.WithField("foo", "bar") + +func TestNew(t *testing.T) { + ctx := New(context.Background(), defaultLogger) + require.Equal(t, defaultLogger, ctx.Log) + require.Equal(t, context.Background(), ctx.Context) +} + +func TestFromGrpcContext(t *testing.T) { + grpcCtx := ctxlogrus.ToContext(context.Background(), defaultLogger) + ctx := FromGrpcCtx(grpcCtx) + require.Equal(t, grpcCtx, ctx.Context) + require.Equal(t, defaultLogger, ctx.Log) +} + +func TestBackground(t *testing.T) { + ctx := Background() + require.Equal(t, ctx.Context, context.Background()) +} + +func TestTODO(t *testing.T) { + ctx := TODO() + require.Equal(t, ctx.Context, context.TODO()) +} + +func TestWithLogField(t *testing.T) { + ctx := WithLogField(Background(), "fish", "chips") + require.Equal(t, context.Background(), ctx.Context) + require.Equal(t, logrus.Fields{"fish": "chips"}, ctx.Log.Data) +} + +func TestWithLogFields(t *testing.T) { + ctx := WithLogFields(Background(), logrus.Fields{"fish": "chips", "salt": "pepper"}) + require.Equal(t, context.Background(), ctx.Context) + require.Equal(t, logrus.Fields{"fish": "chips", "salt": "pepper"}, ctx.Log.Data) +} + +func TestWithTimeout(t *testing.T) { + ctx, _ := WithTimeout(Background(), 100*time.Millisecond) + testDeadline(t, ctx) +} + +func TestWithDeadline(t *testing.T) { + ctx, _ := WithDeadline(Background(), time.Now().Add(100*time.Millisecond)) + testDeadline(t, ctx) +} + +func TestWithValue(t *testing.T) { + ctx := WithValue(Background(), "foo", "bar") + require.Equal(t, "bar", ctx.Value("foo")) +} + +func testDeadline(t *testing.T, c *Context) { + t.Helper() + d := quiescent(t) + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-timer.C: + t.Fatalf("context not timed out after %v", d) + case <-c.Done(): + } + if e := c.Err(); e != context.DeadlineExceeded { + t.Errorf("c.Err() == %v; want %v", e, context.DeadlineExceeded) + } +} + +func quiescent(t *testing.T) time.Duration { + deadline, ok := t.Deadline() + if !ok { + return 5 * time.Second + } + + const arbitraryCleanupMargin = 1 * time.Second + return time.Until(deadline) - arbitraryCleanupMargin +} diff --git a/internal/common/auth/authorization/kubernetes_test.go b/internal/common/auth/authorization/kubernetes_test.go index 9493c71f80a..eef827f9add 100644 --- a/internal/common/auth/authorization/kubernetes_test.go +++ b/internal/common/auth/authorization/kubernetes_test.go @@ -10,11 +10,10 @@ import ( "time" "github.com/grpc-ecosystem/go-grpc-middleware/util/metautils" - authv1 "k8s.io/api/authentication/v1" - "k8s.io/apimachinery/pkg/util/clock" - "github.com/patrickmn/go-cache" "github.com/stretchr/testify/assert" + authv1 "k8s.io/api/authentication/v1" + "k8s.io/apimachinery/pkg/util/clock" "github.com/armadaproject/armada/internal/common/auth/configuration" ) diff --git 
a/internal/common/certs/cached_certificate.go b/internal/common/certs/cached_certificate.go index 2588d0f5b50..72b7f6ea250 100644 --- a/internal/common/certs/cached_certificate.go +++ b/internal/common/certs/cached_certificate.go @@ -1,13 +1,14 @@ package certs import ( - "context" "crypto/tls" "os" "sync" "time" log "github.com/sirupsen/logrus" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) type CachedCertificateService struct { @@ -52,7 +53,7 @@ func (c *CachedCertificateService) updateCertificate(certificate *tls.Certificat c.certificate = certificate } -func (c *CachedCertificateService) Run(ctx context.Context) { +func (c *CachedCertificateService) Run(ctx *armadacontext.Context) { ticker := time.NewTicker(c.refreshInterval) for { select { diff --git a/internal/common/certs/cached_certificate_test.go b/internal/common/certs/cached_certificate_test.go index 7687c80fd63..4edd3efd376 100644 --- a/internal/common/certs/cached_certificate_test.go +++ b/internal/common/certs/cached_certificate_test.go @@ -2,7 +2,6 @@ package certs import ( "bytes" - "context" "crypto/rand" "crypto/rsa" "crypto/tls" @@ -16,6 +15,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) const ( @@ -96,7 +97,7 @@ func TestCachedCertificateService_ReloadsCertPeriodically_WhenUsingRun(t *testin assert.Equal(t, cert, cachedCertService.GetCertificate()) go func() { - cachedCertService.Run(context.Background()) + cachedCertService.Run(armadacontext.Background()) }() newCert, certData, keyData := createCerts(t) diff --git a/internal/common/client.go b/internal/common/client.go index 0b44c374d0b..afc5bb5c597 100644 --- a/internal/common/client.go +++ b/internal/common/client.go @@ -3,8 +3,10 @@ package common import ( "context" "time" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) -func ContextWithDefaultTimeout() (context.Context, context.CancelFunc) { - return context.WithTimeout(context.Background(), 10*time.Second) +func ContextWithDefaultTimeout() (*armadacontext.Context, context.CancelFunc) { + return armadacontext.WithTimeout(armadacontext.Background(), 10*time.Second) } diff --git a/internal/common/database/db_testutil.go b/internal/common/database/db_testutil.go index a36affdef73..416b348d7d8 100644 --- a/internal/common/database/db_testutil.go +++ b/internal/common/database/db_testutil.go @@ -1,7 +1,6 @@ package database import ( - "context" "fmt" "github.com/jackc/pgx/v5" @@ -10,6 +9,7 @@ import ( "github.com/pkg/errors" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/util" ) @@ -17,7 +17,7 @@ import ( // migrations: perform the list of migrations before entering the action callback // action: callback for client code func WithTestDb(migrations []Migration, action func(db *pgxpool.Pool) error) error { - ctx := context.Background() + ctx := armadacontext.Background() // Connect and create a dedicated database for the test dbName := "test_" + util.NewULID() @@ -67,7 +67,7 @@ func WithTestDb(migrations []Migration, action func(db *pgxpool.Pool) error) err // config: PostgresConfig to specify connection details to database // action: callback for client code func WithTestDbCustom(migrations []Migration, config configuration.PostgresConfig, action func(db *pgxpool.Pool) error) error { - ctx := context.Background() + ctx := 
armadacontext.Background() testDbPool, err := OpenPgxPool(config) if err != nil { diff --git a/internal/common/database/functions.go b/internal/common/database/functions.go index 5446f7cd0e1..17f3334efab 100644 --- a/internal/common/database/functions.go +++ b/internal/common/database/functions.go @@ -1,7 +1,6 @@ package database import ( - "context" "database/sql" "fmt" "strings" @@ -13,6 +12,7 @@ import ( "github.com/pkg/errors" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" ) func CreateConnectionString(values map[string]string) string { @@ -26,20 +26,20 @@ func CreateConnectionString(values map[string]string) string { } func OpenPgxConn(config configuration.PostgresConfig) (*pgx.Conn, error) { - db, err := pgx.Connect(context.Background(), CreateConnectionString(config.Connection)) + db, err := pgx.Connect(armadacontext.Background(), CreateConnectionString(config.Connection)) if err != nil { return nil, err } - err = db.Ping(context.Background()) + err = db.Ping(armadacontext.Background()) return db, err } func OpenPgxPool(config configuration.PostgresConfig) (*pgxpool.Pool, error) { - db, err := pgxpool.New(context.Background(), CreateConnectionString(config.Connection)) + db, err := pgxpool.New(armadacontext.Background(), CreateConnectionString(config.Connection)) if err != nil { return nil, err } - err = db.Ping(context.Background()) + err = db.Ping(armadacontext.Background()) return db, err } diff --git a/internal/common/database/migrations.go b/internal/common/database/migrations.go index 164c75b313d..b515c94f7fb 100644 --- a/internal/common/database/migrations.go +++ b/internal/common/database/migrations.go @@ -2,7 +2,6 @@ package database import ( "bytes" - "context" "io/fs" "path" "sort" @@ -11,6 +10,8 @@ import ( stakikfs "github.com/rakyll/statik/fs" log "github.com/sirupsen/logrus" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) // Migration represents a single, versioned database migration script @@ -28,7 +29,7 @@ func NewMigration(id int, name string, sql string) Migration { } } -func UpdateDatabase(ctx context.Context, db Querier, migrations []Migration) error { +func UpdateDatabase(ctx *armadacontext.Context, db Querier, migrations []Migration) error { log.Info("Updating postgres...") version, err := readVersion(ctx, db) if err != nil { @@ -55,7 +56,7 @@ func UpdateDatabase(ctx context.Context, db Querier, migrations []Migration) err return nil } -func readVersion(ctx context.Context, db Querier) (int, error) { +func readVersion(ctx *armadacontext.Context, db Querier) (int, error) { _, err := db.Exec(ctx, `CREATE SEQUENCE IF NOT EXISTS database_version START WITH 0 MINVALUE 0;`) if err != nil { @@ -75,7 +76,7 @@ func readVersion(ctx context.Context, db Querier) (int, error) { return version, err } -func setVersion(ctx context.Context, db Querier, version int) error { +func setVersion(ctx *armadacontext.Context, db Querier, version int) error { _, err := db.Exec(ctx, `SELECT setval('database_version', $1)`, version) return err } diff --git a/internal/common/database/types/types.go b/internal/common/database/types/types.go index eb4f8d426be..2171d10bad1 100644 --- a/internal/common/database/types/types.go +++ b/internal/common/database/types/types.go @@ -1,10 +1,10 @@ package types import ( - "context" "database/sql" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" ) type DatabaseConnection 
interface { @@ -16,25 +16,25 @@ type DatabaseConnection interface { // executing queries, and starting transactions. type DatabaseConn interface { // Close closes the database connection. It returns any error encountered during the closing operation. - Close(context.Context) error + Close(*armadacontext.Context) error // Ping pings the database to check the connection. It returns any error encountered during the ping operation. - Ping(context.Context) error + Ping(*armadacontext.Context) error // Exec executes a query that doesn't return rows. It returns any error encountered. - Exec(context.Context, string, ...any) (any, error) + Exec(*armadacontext.Context, string, ...any) (any, error) // Query executes a query that returns multiple rows. It returns a DatabaseRows interface that allows you to iterate over the result set, and any error encountered. - Query(context.Context, string, ...any) (DatabaseRows, error) + Query(*armadacontext.Context, string, ...any) (DatabaseRows, error) // QueryRow executes a query that returns one row. It returns a DatabaseRow interface representing the result row, and any error encountered. - QueryRow(context.Context, string, ...any) DatabaseRow + QueryRow(*armadacontext.Context, string, ...any) DatabaseRow // BeginTx starts a transcation with the given DatabaseTxOptions, or returns an error if any occurred. - BeginTx(context.Context, DatabaseTxOptions) (DatabaseTx, error) + BeginTx(*armadacontext.Context, DatabaseTxOptions) (DatabaseTx, error) // BeginTxFunc starts a transaction and executes the given function within the transaction. It the function runs successfully, BeginTxFunc commits the transaction, otherwise it rolls back and return an errorr. - BeginTxFunc(context.Context, DatabaseTxOptions, func(DatabaseTx) error) error + BeginTxFunc(*armadacontext.Context, DatabaseTxOptions, func(DatabaseTx) error) error } type DatabaseTxOptions struct { @@ -47,52 +47,52 @@ type DatabaseTxOptions struct { // managing transactions, and performing bulk insertions. type DatabaseTx interface { // Exec executes a query that doesn't return rows. It returns any error encountered. - Exec(context.Context, string, ...any) (any, error) + Exec(*armadacontext.Context, string, ...any) (any, error) // Query executes a query that returns multiple rows. // It returns a DatabaseRows interface that allows you to iterate over the result set, and any error encountered. - Query(context.Context, string, ...any) (DatabaseRows, error) + Query(*armadacontext.Context, string, ...any) (DatabaseRows, error) // QueryRow executes a query that returns one row. // It returns a DatabaseRow interface representing the result row, and any error encountered. - QueryRow(context.Context, string, ...any) DatabaseRow + QueryRow(*armadacontext.Context, string, ...any) DatabaseRow // CopyFrom performs a bulk insertion of data into a specified table. // It accepts the table name, column names, and a slice of rows representing the data to be inserted. It returns the number of rows inserted and any error encountered. - CopyFrom(ctx context.Context, tableName string, columnNames []string, rows [][]any) (int64, error) + CopyFrom(ctx *armadacontext.Context, tableName string, columnNames []string, rows [][]any) (int64, error) // Commit commits the transaction. It returns any error encountered during the commit operation. - Commit(context.Context) error + Commit(*armadacontext.Context) error // Rollback rolls back the transaction. It returns any error encountered during the rollback operation. 
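Editor's sketch: a hypothetical caller of these refactored interfaces, showing the armada context being threaded through a transaction; the table, column, and helper name are illustrative and not taken from this diff.

// Run one unit of work inside a transaction: BeginTxFunc commits on a nil
// return and rolls back otherwise, as documented above.
func insertMessage(ctx *armadacontext.Context, pool types.DatabasePool, msg string) error {
	return pool.BeginTxFunc(ctx, types.DatabaseTxOptions{}, func(tx types.DatabaseTx) error {
		_, err := tx.Exec(ctx, "INSERT INTO messages (body) VALUES ($1)", msg)
		return err
	})
}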
- Rollback(context.Context) error + Rollback(*armadacontext.Context) error } // DatabasePool represents a database connection pool interface that provides methods for acquiring and managing database connections. type DatabasePool interface { // Acquire acquires a database connection from the pool. // It takes a context and returns a DatabaseConn representing the acquired connection and any encountered error. - Acquire(context.Context) (DatabaseConn, error) + Acquire(*armadacontext.Context) (DatabaseConn, error) // Ping pings the database to check the connection. It returns any error encountered during the ping operation. - Ping(context.Context) error + Ping(*armadacontext.Context) error // Close closes the database connection. It returns any error encountered during the closing operation. Close() // Exec executes a query that doesn't return rows. It returns any error encountered. - Exec(context.Context, string, ...any) (any, error) + Exec(*armadacontext.Context, string, ...any) (any, error) // Query executes a query that returns multiple rows. // It returns a DatabaseRows interface that allows you to iterate over the result set, and any error encountered. - Query(context.Context, string, ...any) (DatabaseRows, error) + Query(*armadacontext.Context, string, ...any) (DatabaseRows, error) // BeginTx starts a transcation with the given DatabaseTxOptions, or returns an error if any occurred. - BeginTx(context.Context, DatabaseTxOptions) (DatabaseTx, error) + BeginTx(*armadacontext.Context, DatabaseTxOptions) (DatabaseTx, error) // BeginTxFunc starts a transaction and executes the given function within the transaction. // It the function runs successfully, BeginTxFunc commits the transaction, otherwise it rolls back and return an error. - BeginTxFunc(context.Context, DatabaseTxOptions, func(DatabaseTx) error) error + BeginTxFunc(*armadacontext.Context, DatabaseTxOptions, func(DatabaseTx) error) error } // DatabaseRow represents a single row in a result set. diff --git a/internal/common/database/upsert.go b/internal/common/database/upsert.go index 23f27164f9b..5df05c67918 100644 --- a/internal/common/database/upsert.go +++ b/internal/common/database/upsert.go @@ -1,19 +1,19 @@ package database import ( - "context" "fmt" "reflect" "strings" - "github.com/jackc/pgx/v5/pgxpool" - "github.com/google/uuid" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" "github.com/pkg/errors" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) -func UpsertWithTransaction[T any](ctx context.Context, db *pgxpool.Pool, tableName string, records []T) error { +func UpsertWithTransaction[T any](ctx *armadacontext.Context, db *pgxpool.Pool, tableName string, records []T) error { if len(records) == 0 { return nil } @@ -50,7 +50,7 @@ func UpsertWithTransaction[T any](ctx context.Context, db *pgxpool.Pool, tableNa // // ) // I.e., it should omit everything before and after the "(" and ")", respectively. 
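Editor's sketch: UpsertWithTransaction above now takes the armada context and passes it down to pgx. A call would look roughly like this; db is assumed to be a *pgxpool.Pool, and records and the table name are illustrative, as in the tests below.

// Bulk upsert a slice of struct values whose fields match the target table.
if err := database.UpsertWithTransaction(ctx, db, "records", records); err != nil {
	ctx.Log.WithError(err).Error("bulk upsert failed")
}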
-func Upsert[T any](ctx context.Context, tx pgx.Tx, tableName string, records []T) error { +func Upsert[T any](ctx *armadacontext.Context, tx pgx.Tx, tableName string, records []T) error { if len(records) < 1 { return nil } diff --git a/internal/common/database/upsert_test.go b/internal/common/database/upsert_test.go index b1329921c1e..638d15ac494 100644 --- a/internal/common/database/upsert_test.go +++ b/internal/common/database/upsert_test.go @@ -1,7 +1,6 @@ package database import ( - "context" "fmt" "testing" "time" @@ -9,6 +8,8 @@ import ( "github.com/google/uuid" "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/assert" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) // Used for tests. @@ -55,7 +56,7 @@ func TestNamesValuesFromRecordPointer(t *testing.T) { } func TestUpsertWithTransaction(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Hour) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 10*time.Hour) defer cancel() err := withDb(func(db *pgxpool.Pool) error { // Insert rows, read them back, and compare. @@ -90,7 +91,7 @@ func TestUpsertWithTransaction(t *testing.T) { } func TestConcurrency(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 10*time.Second) defer cancel() err := withDb(func(db *pgxpool.Pool) error { // Each thread inserts non-overlapping rows, reads them back, and compares. @@ -125,7 +126,7 @@ func TestConcurrency(t *testing.T) { } func TestAutoIncrement(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 10*time.Second) defer cancel() err := withDb(func(db *pgxpool.Pool) error { // Insert two rows. These should automatically get auto-incrementing serial numbers. 
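Editor's note: the tests above bound their work with armadacontext.WithTimeout; the same shape works anywhere a deadline is wanted, since the derived context still satisfies plain context.Context. A minimal sketch, assuming db is a *pgxpool.Pool.

ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 10*time.Second)
defer cancel() // release the timer even if the deadline never fires
if err := db.Ping(ctx); err != nil {
	ctx.Log.WithError(err).Error("database not reachable before the deadline")
}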
@@ -207,7 +208,7 @@ func setMessageToExecutor(runs []Record, executor string) { } func selectRecords(db *pgxpool.Pool) ([]Record, error) { - rows, err := db.Query(context.Background(), fmt.Sprintf("SELECT id, message, value, serial FROM %s order by value", TABLE_NAME)) + rows, err := db.Query(armadacontext.Background(), fmt.Sprintf("SELECT id, message, value, serial FROM %s order by value", TABLE_NAME)) if err != nil { return nil, err } diff --git a/internal/common/etcdhealth/etcdhealth.go b/internal/common/etcdhealth/etcdhealth.go index 804a89542f4..49be27a22fe 100644 --- a/internal/common/etcdhealth/etcdhealth.go +++ b/internal/common/etcdhealth/etcdhealth.go @@ -1,7 +1,6 @@ package etcdhealth import ( - "context" "sync" "time" @@ -9,6 +8,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/healthmonitor" "github.com/armadaproject/armada/internal/common/logging" "github.com/armadaproject/armada/internal/common/metrics" @@ -184,7 +184,7 @@ func (srv *EtcdReplicaHealthMonitor) sizeFraction() float64 { return srv.etcdSizeBytes / srv.etcdCapacityBytes } -func (srv *EtcdReplicaHealthMonitor) Run(ctx context.Context, log *logrus.Entry) error { +func (srv *EtcdReplicaHealthMonitor) Run(ctx *armadacontext.Context, log *logrus.Entry) error { log = log.WithField("service", "EtcdHealthMonitor") log.Info("starting etcd health monitor") defer log.Info("stopping etcd health monitor") @@ -264,7 +264,7 @@ func (srv *EtcdReplicaHealthMonitor) setCapacityBytesFromMetrics(metrics map[str // BlockUntilNextMetricsCollection blocks until the next metrics collection has completed, // or until ctx is cancelled, whichever occurs first. -func (srv *EtcdReplicaHealthMonitor) BlockUntilNextMetricsCollection(ctx context.Context) { +func (srv *EtcdReplicaHealthMonitor) BlockUntilNextMetricsCollection(ctx *armadacontext.Context) { c := make(chan struct{}) srv.mu.Lock() srv.watchers = append(srv.watchers, c) diff --git a/internal/common/etcdhealth/etcdhealth_test.go b/internal/common/etcdhealth/etcdhealth_test.go index 22435861a61..474d4df0e3a 100644 --- a/internal/common/etcdhealth/etcdhealth_test.go +++ b/internal/common/etcdhealth/etcdhealth_test.go @@ -1,14 +1,13 @@ package etcdhealth import ( - "context" "testing" "time" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" - "golang.org/x/sync/errgroup" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/healthmonitor" "github.com/armadaproject/armada/internal/common/metrics" ) @@ -24,9 +23,9 @@ func TestEtcdReplicaHealthMonitor(t *testing.T) { assert.NoError(t, err) // Start the metrics collection service. - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := armadacontext.WithCancel(armadacontext.Background()) defer cancel() - g, ctx := errgroup.WithContext(ctx) + g, ctx := armadacontext.ErrGroup(ctx) g.Go(func() error { return hm.Run(ctx, logrus.NewEntry(logrus.New())) }) // Should still be unavailable due to missing metrics. 
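// The etcdhealth test above swaps errgroup.WithContext for armadacontext.ErrGroup.
// This sketch spells that pattern out: the group cancels the shared context when
// any goroutine returns an error, so sibling services stop together. The two
// service functions are placeholders, and it assumes the returned group exposes
// the usual Go/Wait methods, as the one-for-one replacement of errgroup in this
// PR implies.
package example

import (
	"github.com/armadaproject/armada/internal/common/armadacontext"
)

func runServices(serviceA, serviceB func(*armadacontext.Context) error) error {
	ctx, cancel := armadacontext.WithCancel(armadacontext.Background())
	defer cancel()

	g, ctx := armadacontext.ErrGroup(ctx)
	g.Go(func() error { return serviceA(ctx) })
	g.Go(func() error { return serviceB(ctx) })

	// Wait returns the first non-nil error; by then ctx has already been cancelled.
	return g.Wait()
}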
diff --git a/internal/common/eventutil/eventutil.go b/internal/common/eventutil/eventutil.go index 05ee5d473c9..10d5baf4885 100644 --- a/internal/common/eventutil/eventutil.go +++ b/internal/common/eventutil/eventutil.go @@ -1,7 +1,6 @@ package eventutil import ( - "context" "fmt" "math" "time" @@ -14,6 +13,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/armadaerrors" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/executor/configuration" @@ -25,7 +25,7 @@ import ( // UnmarshalEventSequence returns an EventSequence object contained in a byte buffer // after validating that the resulting EventSequence is valid. -func UnmarshalEventSequence(ctx context.Context, payload []byte) (*armadaevents.EventSequence, error) { +func UnmarshalEventSequence(ctx *armadacontext.Context, payload []byte) (*armadaevents.EventSequence, error) { sequence := &armadaevents.EventSequence{} err := proto.Unmarshal(payload, sequence) if err != nil { diff --git a/internal/common/eventutil/sequence_from_message.go b/internal/common/eventutil/sequence_from_message.go deleted file mode 100644 index cc1749c392e..00000000000 --- a/internal/common/eventutil/sequence_from_message.go +++ /dev/null @@ -1,193 +0,0 @@ -package eventutil - -import ( - "context" - "time" - - "github.com/apache/pulsar-client-go/pulsar" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" - - "github.com/armadaproject/armada/internal/common/logging" - "github.com/armadaproject/armada/pkg/armadaevents" -) - -// PulsarToChannel is a service for receiving messages from Pulsar and forwarding those on C. -type SequenceFromMessage struct { - In chan pulsar.Message - Out chan *EventSequenceWithMessageIds -} - -// EventSequenceWithMessageIds bundles an event sequence with -// all the ids of all Pulsar messages that were consumed to produce it. -type EventSequenceWithMessageIds struct { - Sequence *armadaevents.EventSequence - MessageIds []pulsar.MessageID -} - -func NewSequenceFromMessage(in chan pulsar.Message) *SequenceFromMessage { - return &SequenceFromMessage{ - In: in, - Out: make(chan *EventSequenceWithMessageIds), - } -} - -func (srv *SequenceFromMessage) Run(ctx context.Context) error { - log := ctxlogrus.Extract(ctx) - for { - select { - case <-ctx.Done(): - return ctx.Err() - case msg := <-srv.In: - if msg == nil { - break - } - sequence, err := UnmarshalEventSequence(ctx, msg.Payload()) - if err != nil { - logging.WithStacktrace(log, err).WithField("messageid", msg.ID()).Error("failed to unmarshal event sequence") - break - } - - sequenceWithMessageIds := &EventSequenceWithMessageIds{ - Sequence: sequence, - MessageIds: []pulsar.MessageID{msg.ID()}, - } - select { - case <-ctx.Done(): - case srv.Out <- sequenceWithMessageIds: - } - } - } -} - -// SequenceCompacter reads sequences and produces compacted sequences. -// Compacted sequences are created by combining events in sequences with the -type SequenceCompacter struct { - In chan *EventSequenceWithMessageIds - Out chan *EventSequenceWithMessageIds - // Buffer messages for at most this long before forwarding on the outgoing channel. - Interval time.Duration - // Max number of events to buffer. - MaxEvents int - // Buffer of events to be compacted and sent. - buffer []*EventSequenceWithMessageIds - // Number of events collected so far. 
- numEvents int -} - -func NewSequenceCompacter(in chan *EventSequenceWithMessageIds) *SequenceCompacter { - return &SequenceCompacter{ - In: in, - Out: make(chan *EventSequenceWithMessageIds), - Interval: 5 * time.Second, - MaxEvents: 10000, - } -} - -func (srv *SequenceCompacter) Run(ctx context.Context) error { - ticker := time.NewTicker(srv.Interval) - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - err := srv.compactAndSend(ctx) - if err != nil { - return err - } - case sequenceWithIds := <-srv.In: - if sequenceWithIds == nil || sequenceWithIds.Sequence == nil { - break - } - srv.buffer = append(srv.buffer, sequenceWithIds) - srv.numEvents += len(sequenceWithIds.Sequence.Events) - if srv.numEvents > srv.MaxEvents { - err := srv.compactAndSend(ctx) - if err != nil { - return err - } - } - } - } -} - -func (srv *SequenceCompacter) compactAndSend(ctx context.Context) error { - if len(srv.buffer) == 0 { - return nil - } - - // Compact the event sequences. - // Note that we can't be sure of the number of message ids. - messageIds := make([]pulsar.MessageID, 0, len(srv.buffer)) - sequences := make([]*armadaevents.EventSequence, len(srv.buffer)) - for i, sequenceWithIds := range srv.buffer { - messageIds = append(messageIds, sequenceWithIds.MessageIds...) - sequences[i] = sequenceWithIds.Sequence - } - sequences = CompactEventSequences(sequences) - - for i, sequence := range sequences { - sequenceWithIds := &EventSequenceWithMessageIds{ - Sequence: sequence, - } - - // Add all message ids to the last sequence to be produced. - // To avoid later ack'ing messages the data of which has not yet been processed. - if i == len(sequences)-1 { - sequenceWithIds.MessageIds = messageIds - } - - select { - case <-ctx.Done(): - return ctx.Err() - case srv.Out <- sequenceWithIds: - } - } - - // Empty the buffer. - srv.buffer = nil - srv.numEvents = 0 - - return nil -} - -// EventFilter calls filter once for each event, -// and events for which filter returns false are discarded. -type EventFilter struct { - In chan *EventSequenceWithMessageIds - Out chan *EventSequenceWithMessageIds - // Filter function. Discard on returning false. 
- filter func(*armadaevents.EventSequence_Event) bool -} - -func NewEventFilter(in chan *EventSequenceWithMessageIds, filter func(*armadaevents.EventSequence_Event) bool) *EventFilter { - return &EventFilter{ - In: in, - Out: make(chan *EventSequenceWithMessageIds), - filter: filter, - } -} - -func (srv *EventFilter) Run(ctx context.Context) error { - for { - select { - case <-ctx.Done(): - return ctx.Err() - case sequenceWithIds := <-srv.In: - if sequenceWithIds == nil { - break - } - events := make([]*armadaevents.EventSequence_Event, 0, len(sequenceWithIds.Sequence.Events)) - for _, event := range sequenceWithIds.Sequence.Events { - if srv.filter(event) { - events = append(events, event) - } - } - sequenceWithIds.Sequence.Events = events - - select { - case <-ctx.Done(): - case srv.Out <- sequenceWithIds: - } - } - } -} diff --git a/internal/common/eventutil/sequence_from_message_test.go b/internal/common/eventutil/sequence_from_message_test.go deleted file mode 100644 index a4a1812b207..00000000000 --- a/internal/common/eventutil/sequence_from_message_test.go +++ /dev/null @@ -1,73 +0,0 @@ -package eventutil - -// import ( -// "context" -// "testing" -// "time" - -// "github.com/armadaproject/armada/internal/pulsarutils" -// "github.com/armadaproject/armada/pkg/armadaevents" -// "github.com/apache/pulsar-client-go/pulsar" -// ) - -// func TestSequenceCompacter(t *testing.T) { - -// } - -// func TestEventFilter(t *testing.T) { -// tests := map[string]struct { -// filter func(*armadaevents.EventSequence_Event) bool -// n int // Number of event expected to pass the filter -// }{ -// "filter all": { -// filter: func(a *armadaevents.EventSequence_Event) bool { -// return false -// }, -// n: 0, -// }, -// "filter none": { -// filter: func(a *armadaevents.EventSequence_Event) bool { -// return true -// }, -// n: 1, -// }, -// } -// for name, tc := range tests { -// t.Run(name, func(t *testing.T) { -// C := make(chan *EventSequenceWithMessageIds, 1) -// eventFilter := NewEventFilter(C, tc.filter) -// ctx, _ := context.WithTimeout(context.Background(), time.Second) -// sequence := &EventSequenceWithMessageIds{ -// Sequence: &armadaevents.EventSequence{ -// Events: []*armadaevents.EventSequence_Event{ -// {Event: nil}, -// {Event: &armadaevents.EventSequence_Event_SubmitJob{}}, -// }, -// }, -// MessageIds: []pulsar.MessageID{pulsarutils.New(0, i, 0, 0)}, -// } -// C <- sequence - -// }) -// } -// } - -// func generateEvents(ctx context.Context, out chan *EventSequenceWithMessageIds) error { -// var i int64 -// for { -// sequence := EventSequenceWithMessageIds{ -// Sequence: &armadaevents.EventSequence{ -// Events: []*armadaevents.EventSequence_Event{ -// {Event: nil}, -// {Event: &armadaevents.EventSequence_Event_SubmitJob{}}, -// }, -// }, -// MessageIds: []pulsar.MessageID{pulsarutils.New(0, i, 0, 0)}, -// } -// select { -// case <-ctx.Done(): -// return ctx.Err() -// case out <- &sequence: -// } -// } -// } diff --git a/internal/common/grpc/grpc.go b/internal/common/grpc/grpc.go index 5f73c3801c0..43707dffadf 100644 --- a/internal/common/grpc/grpc.go +++ b/internal/common/grpc/grpc.go @@ -1,7 +1,6 @@ package grpc import ( - "context" "crypto/tls" "fmt" "net" @@ -23,6 +22,7 @@ import ( "google.golang.org/grpc/keepalive" "google.golang.org/grpc/status" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/armadaerrors" "github.com/armadaproject/armada/internal/common/auth/authorization" 
"github.com/armadaproject/armada/internal/common/certs" @@ -91,7 +91,7 @@ func CreateGrpcServer( if tlsConfig.Enabled { cachedCertificateService := certs.NewCachedCertificateService(tlsConfig.CertPath, tlsConfig.KeyPath, time.Minute) go func() { - cachedCertificateService.Run(context.Background()) + cachedCertificateService.Run(armadacontext.Background()) }() tlsCreds := credentials.NewTLS(&tls.Config{ GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { @@ -130,7 +130,7 @@ func Listen(port uint16, grpcServer *grpc.Server, wg *sync.WaitGroup) { // CreateShutdownHandler returns a function that shuts down the grpcServer when the context is closed. // The server is given gracePeriod to perform a graceful showdown and is then forcably stopped if necessary -func CreateShutdownHandler(ctx context.Context, gracePeriod time.Duration, grpcServer *grpc.Server) func() error { +func CreateShutdownHandler(ctx *armadacontext.Context, gracePeriod time.Duration, grpcServer *grpc.Server) func() error { return func() error { <-ctx.Done() go func() { diff --git a/internal/common/healthmonitor/healthmonitor.go b/internal/common/healthmonitor/healthmonitor.go index aa196aaffda..d5c6b151c1e 100644 --- a/internal/common/healthmonitor/healthmonitor.go +++ b/internal/common/healthmonitor/healthmonitor.go @@ -1,10 +1,10 @@ package healthmonitor import ( - "context" - "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) const ( @@ -25,5 +25,5 @@ type HealthMonitor interface { // Run initialises and starts the health checker. // Run may be blocking and should be run within a separate goroutine. // Must be called before IsHealthy() or any prometheus.Collector interface methods. - Run(context.Context, *logrus.Entry) error + Run(*armadacontext.Context, *logrus.Entry) error } diff --git a/internal/common/healthmonitor/manualhealthmonitor.go b/internal/common/healthmonitor/manualhealthmonitor.go index 1bc8a6d5b62..7aa2f525068 100644 --- a/internal/common/healthmonitor/manualhealthmonitor.go +++ b/internal/common/healthmonitor/manualhealthmonitor.go @@ -1,11 +1,12 @@ package healthmonitor import ( - "context" "sync" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) // ManualHealthMonitor is a manually controlled health monitor. @@ -46,7 +47,7 @@ func (srv *ManualHealthMonitor) IsHealthy() (bool, string, error) { } } -func (srv *ManualHealthMonitor) Run(ctx context.Context, log *logrus.Entry) error { +func (srv *ManualHealthMonitor) Run(_ *armadacontext.Context, _ *logrus.Entry) error { return nil } diff --git a/internal/common/healthmonitor/multihealthmonitor.go b/internal/common/healthmonitor/multihealthmonitor.go index 8d9790fd91e..a9f03643d10 100644 --- a/internal/common/healthmonitor/multihealthmonitor.go +++ b/internal/common/healthmonitor/multihealthmonitor.go @@ -1,7 +1,6 @@ package healthmonitor import ( - "context" "fmt" "sync" @@ -9,7 +8,8 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" "golang.org/x/exp/maps" - "golang.org/x/sync/errgroup" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) // MultiHealthMonitor wraps multiple HealthMonitors and itself implements the HealthMonitor interface. 
@@ -100,8 +100,8 @@ func (srv *MultiHealthMonitor) IsHealthy() (ok bool, reason string, err error) { } // Run initialises prometheus metrics and starts any child health checkers. -func (srv *MultiHealthMonitor) Run(ctx context.Context, log *logrus.Entry) error { - g, ctx := errgroup.WithContext(ctx) +func (srv *MultiHealthMonitor) Run(ctx *armadacontext.Context, log *logrus.Entry) error { + g, ctx := armadacontext.ErrGroup(ctx) for _, healthMonitor := range srv.healthMonitorsByName { healthMonitor := healthMonitor g.Go(func() error { return healthMonitor.Run(ctx, log) }) diff --git a/internal/common/ingest/batch.go b/internal/common/ingest/batch.go index 7f07c915855..f099f646fae 100644 --- a/internal/common/ingest/batch.go +++ b/internal/common/ingest/batch.go @@ -1,12 +1,13 @@ package ingest import ( - "context" "sync" "time" log "github.com/sirupsen/logrus" "k8s.io/apimachinery/pkg/util/clock" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) // Batcher batches up events from a channel. Batches are created whenever maxItems have been @@ -32,7 +33,7 @@ func NewBatcher[T any](input chan T, maxItems int, maxTimeout time.Duration, cal } } -func (b *Batcher[T]) Run(ctx context.Context) { +func (b *Batcher[T]) Run(ctx *armadacontext.Context) { for { b.buffer = []T{} expire := b.clock.After(b.maxTimeout) diff --git a/internal/common/ingest/batch_test.go b/internal/common/ingest/batch_test.go index 4c9fee650a1..a906dbc8258 100644 --- a/internal/common/ingest/batch_test.go +++ b/internal/common/ingest/batch_test.go @@ -5,11 +5,11 @@ import ( "testing" "time" - "golang.org/x/net/context" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "k8s.io/apimachinery/pkg/util/clock" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) const ( @@ -42,7 +42,7 @@ func (r *resultHolder) resultLength() int { } func TestBatch_MaxItems(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) testClock := clock.NewFakeClock(time.Now()) inputChan := make(chan int) result := newResultHolder() @@ -67,7 +67,7 @@ func TestBatch_MaxItems(t *testing.T) { } func TestBatch_Time(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) testClock := clock.NewFakeClock(time.Now()) inputChan := make(chan int) result := newResultHolder() @@ -89,7 +89,7 @@ func TestBatch_Time(t *testing.T) { } func TestBatch_Time_WithIntialQuiet(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) testClock := clock.NewFakeClock(time.Now()) inputChan := make(chan int) result := newResultHolder() @@ -120,7 +120,7 @@ func TestBatch_Time_WithIntialQuiet(t *testing.T) { cancel() } -func waitForBufferLength(ctx context.Context, batcher *Batcher[int], numEvents int) error { +func waitForBufferLength(ctx *armadacontext.Context, batcher *Batcher[int], numEvents int) error { ticker := time.NewTicker(5 * time.Millisecond) for { select { @@ -134,7 +134,7 @@ func waitForBufferLength(ctx context.Context, batcher *Batcher[int], numEvents i } } -func waitForExpectedEvents(ctx context.Context, rh *resultHolder, numEvents int) { +func waitForExpectedEvents(ctx *armadacontext.Context, rh *resultHolder, numEvents int) { done := false ticker := 
time.NewTicker(5 * time.Millisecond) for !done { diff --git a/internal/common/ingest/ingestion_pipeline.go b/internal/common/ingest/ingestion_pipeline.go index 2b5e9a9e783..4236473d360 100644 --- a/internal/common/ingest/ingestion_pipeline.go +++ b/internal/common/ingest/ingestion_pipeline.go @@ -1,16 +1,17 @@ package ingest import ( + "context" "sync" "time" "github.com/apache/pulsar-client-go/pulsar" "github.com/pkg/errors" log "github.com/sirupsen/logrus" - "golang.org/x/net/context" "github.com/armadaproject/armada/internal/armada/configuration" "github.com/armadaproject/armada/internal/common" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/eventutil" commonmetrics "github.com/armadaproject/armada/internal/common/ingest/metrics" "github.com/armadaproject/armada/internal/common/pulsarutils" @@ -27,7 +28,7 @@ type HasPulsarMessageIds interface { // InstructionConverter should be implemented by structs that can convert a batch of event sequences into an object // suitable for passing to the sink type InstructionConverter[T HasPulsarMessageIds] interface { - Convert(ctx context.Context, msg *EventSequencesWithIds) T + Convert(ctx *armadacontext.Context, msg *EventSequencesWithIds) T } // Sink should be implemented by the struct responsible for putting the data in its final resting place, e.g. a @@ -35,7 +36,7 @@ type InstructionConverter[T HasPulsarMessageIds] interface { type Sink[T HasPulsarMessageIds] interface { // Store should persist the data in the sink. The store is responsible for retrying failed attempts and should only return an error // when it is satisfied that the operation cannot be retried. - Store(ctx context.Context, msg T) error + Store(ctx *armadacontext.Context, msg T) error } // EventSequencesWithIds consists of a batch of Event Sequences along with the corresponding Pulsar Message Ids @@ -122,7 +123,7 @@ func NewFilteredMsgIngestionPipeline[T HasPulsarMessageIds]( } // Run will run the ingestion pipeline until the supplied context is cancelled -func (ingester *IngestionPipeline[T]) Run(ctx context.Context) error { +func (ingester *IngestionPipeline[T]) Run(ctx *armadacontext.Context) error { shutdownMetricServer := common.ServeMetrics(ingester.metricsConfig.Port) defer shutdownMetricServer() @@ -147,7 +148,7 @@ func (ingester *IngestionPipeline[T]) Run(ctx context.Context) error { // Set up a context that is cancelled n seconds after ctx // This gives the rest of the pipeline a chance to flush pending messages - pipelineShutdownContext, cancel := context.WithCancel(context.Background()) + pipelineShutdownContext, cancel := armadacontext.WithCancel(armadacontext.Background()) go func() { for { select { @@ -206,7 +207,7 @@ func (ingester *IngestionPipeline[T]) Run(ctx context.Context) error { } else { for _, msgId := range msg.GetMessageIDs() { util.RetryUntilSuccess( - context.Background(), + armadacontext.Background(), func() error { return ingester.consumer.AckID(msgId) }, func(err error) { log.WithError(err).Warnf("Pulsar ack failed; backing off for %s", ingester.pulsarConfig.BackoffTime) @@ -265,7 +266,7 @@ func unmarshalEventSequences(batch []pulsar.Message, msgFilter func(msg pulsar.M } // Try to unmarshal the proto - es, err := eventutil.UnmarshalEventSequence(context.Background(), msg.Payload()) + es, err := eventutil.UnmarshalEventSequence(armadacontext.Background(), msg.Payload()) if err != nil { metrics.RecordPulsarMessageError(commonmetrics.PulsarMessageErrorDeserialization) log.WithError(err).Warnf("Could not
unmarshal proto for msg %s", msg.ID()) diff --git a/internal/common/ingest/ingestion_pipeline_test.go b/internal/common/ingest/ingestion_pipeline_test.go index da0d653b39a..53dd6a7a39b 100644 --- a/internal/common/ingest/ingestion_pipeline_test.go +++ b/internal/common/ingest/ingestion_pipeline_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/ingest/metrics" "github.com/armadaproject/armada/internal/common/pulsarutils" "github.com/armadaproject/armada/pkg/armadaevents" @@ -191,7 +192,7 @@ func newSimpleConverter(t *testing.T) InstructionConverter[*simpleMessages] { return &simpleConverter{t} } -func (s *simpleConverter) Convert(_ context.Context, msg *EventSequencesWithIds) *simpleMessages { +func (s *simpleConverter) Convert(_ *armadacontext.Context, msg *EventSequencesWithIds) *simpleMessages { s.t.Helper() assert.Len(s.t, msg.EventSequences, len(msg.MessageIds)) var converted []*simpleMessage @@ -218,7 +219,7 @@ func newSimpleSink(t *testing.T) *simpleSink { } } -func (s *simpleSink) Store(_ context.Context, msg *simpleMessages) error { +func (s *simpleSink) Store(_ *armadacontext.Context, msg *simpleMessages) error { for _, simpleMessage := range msg.msgs { s.simpleMessages[simpleMessage.id] = simpleMessage } @@ -236,7 +237,7 @@ func (s *simpleSink) assertDidProcess(messages []pulsar.Message) { } func TestRun_HappyPath_SingleMessage(t *testing.T) { - ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(10*time.Second)) + ctx, cancel := armadacontext.WithDeadline(armadacontext.Background(), time.Now().Add(10*time.Second)) messages := []pulsar.Message{ pulsarutils.NewPulsarMessage(1, baseTime, marshal(t, succeeded)), } @@ -257,7 +258,7 @@ func TestRun_HappyPath_SingleMessage(t *testing.T) { } func TestRun_HappyPath_MultipleMessages(t *testing.T) { - ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(10*time.Second)) + ctx, cancel := armadacontext.WithDeadline(armadacontext.Background(), time.Now().Add(10*time.Second)) messages := []pulsar.Message{ pulsarutils.NewPulsarMessage(1, baseTime, marshal(t, succeeded)), pulsarutils.NewPulsarMessage(2, baseTime.Add(1*time.Second), marshal(t, pendingAndRunning)), diff --git a/internal/common/pgkeyvalue/pgkeyvalue.go b/internal/common/pgkeyvalue/pgkeyvalue.go index 8476146d727..d3f5f7d9401 100644 --- a/internal/common/pgkeyvalue/pgkeyvalue.go +++ b/internal/common/pgkeyvalue/pgkeyvalue.go @@ -1,7 +1,6 @@ package pgkeyvalue import ( - "context" "fmt" "time" @@ -10,6 +9,7 @@ import ( "github.com/sirupsen/logrus" "k8s.io/apimachinery/pkg/util/clock" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/armadaerrors" "github.com/armadaproject/armada/internal/common/database" "github.com/armadaproject/armada/internal/common/logging" @@ -34,7 +34,7 @@ type PGKeyValueStore struct { clock clock.Clock } -func New(ctx context.Context, db *pgxpool.Pool, tableName string) (*PGKeyValueStore, error) { +func New(ctx *armadacontext.Context, db *pgxpool.Pool, tableName string) (*PGKeyValueStore, error) { if db == nil { return nil, errors.WithStack(&armadaerrors.ErrInvalidArgument{ Name: "db", @@ -60,7 +60,7 @@ func New(ctx context.Context, db *pgxpool.Pool, tableName string) (*PGKeyValueSt }, nil } -func (c *PGKeyValueStore) Load(ctx context.Context, keys []string) 
(map[string][]byte, error) { +func (c *PGKeyValueStore) Load(ctx *armadacontext.Context, keys []string) (map[string][]byte, error) { rows, err := c.db.Query(ctx, fmt.Sprintf("SELECT KEY, VALUE FROM %s WHERE KEY = any($1)", c.tableName), keys) if err != nil { return nil, errors.WithStack(err) @@ -78,7 +78,7 @@ func (c *PGKeyValueStore) Load(ctx context.Context, keys []string) (map[string][ return kv, nil } -func (c *PGKeyValueStore) Store(ctx context.Context, kvs map[string][]byte) error { +func (c *PGKeyValueStore) Store(ctx *armadacontext.Context, kvs map[string][]byte) error { data := make([]KeyValue, 0, len(kvs)) for k, v := range kvs { data = append(data, KeyValue{ @@ -90,7 +90,7 @@ func (c *PGKeyValueStore) Store(ctx context.Context, kvs map[string][]byte) erro return database.UpsertWithTransaction(ctx, c.db, c.tableName, data) } -func createTableIfNotExists(ctx context.Context, db *pgxpool.Pool, tableName string) error { +func createTableIfNotExists(ctx *armadacontext.Context, db *pgxpool.Pool, tableName string) error { _, err := db.Exec(ctx, fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( key TEXT PRIMARY KEY, @@ -101,7 +101,7 @@ func createTableIfNotExists(ctx context.Context, db *pgxpool.Pool, tableName str } // Cleanup removes all key-value pairs older than lifespan. -func (c *PGKeyValueStore) cleanup(ctx context.Context, lifespan time.Duration) error { +func (c *PGKeyValueStore) cleanup(ctx *armadacontext.Context, lifespan time.Duration) error { sql := fmt.Sprintf("DELETE FROM %s WHERE (inserted <= $1);", c.tableName) _, err := c.db.Exec(ctx, sql, c.clock.Now().Add(-lifespan)) if err != nil { @@ -112,7 +112,7 @@ func (c *PGKeyValueStore) cleanup(ctx context.Context, lifespan time.Duration) e // PeriodicCleanup starts a goroutine that automatically runs the cleanup job // every interval until the provided context is cancelled. 
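// A short usage sketch for PGKeyValueStore with the new context type. The
// New/Store/Load signatures and the PeriodicCleanup documented just above come
// from the surrounding diff; the table name, keys, intervals and the background
// goroutine are illustrative, and the pgx pool is assumed to be opened elsewhere.
package example

import (
	"time"

	"github.com/jackc/pgx/v5/pgxpool"

	"github.com/armadaproject/armada/internal/common/armadacontext"
	"github.com/armadaproject/armada/internal/common/pgkeyvalue"
)

func cacheExample(ctx *armadacontext.Context, db *pgxpool.Pool) ([]byte, error) {
	kvStore, err := pgkeyvalue.New(ctx, db, "example_cache")
	if err != nil {
		return nil, err
	}

	// Remove entries older than an hour, checking every minute, until ctx is
	// cancelled; the returned error is ignored in this sketch.
	go func() { _ = kvStore.PeriodicCleanup(ctx, time.Minute, time.Hour) }()

	if err := kvStore.Store(ctx, map[string][]byte{"a": []byte("1")}); err != nil {
		return nil, err
	}
	loaded, err := kvStore.Load(ctx, []string{"a"})
	if err != nil {
		return nil, err
	}
	return loaded["a"], nil
}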
-func (c *PGKeyValueStore) PeriodicCleanup(ctx context.Context, interval time.Duration, lifespan time.Duration) error { +func (c *PGKeyValueStore) PeriodicCleanup(ctx *armadacontext.Context, interval time.Duration, lifespan time.Duration) error { log := logrus.StandardLogger().WithField("service", "PGKeyValueStoreCleanup") log.Info("service started") ticker := c.clock.NewTicker(interval) diff --git a/internal/common/pgkeyvalue/pgkeyvalue_test.go b/internal/common/pgkeyvalue/pgkeyvalue_test.go index c8a9beeb175..aa70c4ed7b9 100644 --- a/internal/common/pgkeyvalue/pgkeyvalue_test.go +++ b/internal/common/pgkeyvalue/pgkeyvalue_test.go @@ -1,7 +1,6 @@ package pgkeyvalue import ( - "context" "testing" "time" @@ -11,11 +10,12 @@ import ( "golang.org/x/exp/maps" "k8s.io/apimachinery/pkg/util/clock" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/lookout/testutil" ) func TestLoadStore(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 10*time.Second) defer cancel() err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { kvStore, err := New(ctx, db, "cachetable") @@ -47,7 +47,7 @@ func TestLoadStore(t *testing.T) { } func TestCleanup(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 10*time.Second) defer cancel() err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { baseTime := time.Now() diff --git a/internal/common/pulsarutils/async.go b/internal/common/pulsarutils/async.go index 8f71781d558..9040eed5fe9 100644 --- a/internal/common/pulsarutils/async.go +++ b/internal/common/pulsarutils/async.go @@ -7,11 +7,11 @@ import ( "sync" "time" - commonmetrics "github.com/armadaproject/armada/internal/common/ingest/metrics" - "github.com/apache/pulsar-client-go/pulsar" "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" + commonmetrics "github.com/armadaproject/armada/internal/common/ingest/metrics" "github.com/armadaproject/armada/internal/common/logging" "github.com/armadaproject/armada/internal/common/util" ) @@ -36,7 +36,7 @@ type ConsumerMessage struct { var msgLogger = logrus.NewEntry(logrus.StandardLogger()) func Receive( - ctx context.Context, + ctx *armadacontext.Context, consumer pulsar.Consumer, receiveTimeout time.Duration, backoffTime time.Duration, @@ -76,7 +76,7 @@ func Receive( return default: // Get a message from Pulsar, which consists of a sequence of events (i.e., state transitions). - ctxWithTimeout, cancel := context.WithTimeout(ctx, receiveTimeout) + ctxWithTimeout, cancel := armadacontext.WithTimeout(ctx, receiveTimeout) msg, err := consumer.Receive(ctxWithTimeout) if errors.Is(err, context.DeadlineExceeded) { msgLogger.Debugf("No message received") @@ -109,7 +109,7 @@ func Receive( // Ack will ack all pulsar messages coming in on the msgs channel. The incoming messages contain a consumer id which // corresponds to the index of the consumer that should be used to perform the ack. 
In theory, the acks could be done // in parallel, however its unlikely that they will be a performance bottleneck -func Ack(ctx context.Context, consumers []pulsar.Consumer, msgs chan []*ConsumerMessageId, backoffTime time.Duration, wg *sync.WaitGroup) { +func Ack(ctx *armadacontext.Context, consumers []pulsar.Consumer, msgs chan []*ConsumerMessageId, backoffTime time.Duration, wg *sync.WaitGroup) { for msg := range msgs { for _, id := range msg { if id.ConsumerId < 0 || id.ConsumerId >= len(consumers) { diff --git a/internal/common/pulsarutils/async_test.go b/internal/common/pulsarutils/async_test.go index d47151c660d..bb8739254df 100644 --- a/internal/common/pulsarutils/async_test.go +++ b/internal/common/pulsarutils/async_test.go @@ -1,16 +1,16 @@ package pulsarutils import ( - ctx "context" + "context" "sync" "testing" "time" - "github.com/armadaproject/armada/internal/common/ingest/metrics" - "github.com/apache/pulsar-client-go/pulsar" "github.com/stretchr/testify/assert" - "golang.org/x/net/context" + + "github.com/armadaproject/armada/internal/common/armadacontext" + "github.com/armadaproject/armada/internal/common/ingest/metrics" ) var m = metrics.NewMetrics("test_pulsarutils_") @@ -46,8 +46,8 @@ func TestReceive(t *testing.T) { consumer := &mockConsumer{ msgs: msgs, } - context, cancel := ctx.WithCancel(ctx.Background()) - outputChan := Receive(context, consumer, 10*time.Millisecond, 10*time.Millisecond, m) + ctx, cancel := armadacontext.WithCancel(armadacontext.Background()) + outputChan := Receive(ctx, consumer, 10*time.Millisecond, 10*time.Millisecond, m) var receivedMsgs []pulsar.Message wg := sync.WaitGroup{} @@ -71,7 +71,7 @@ func TestAcks(t *testing.T) { consumers := []pulsar.Consumer{&mockConsumer} wg := sync.WaitGroup{} wg.Add(1) - go Ack(ctx.Background(), consumers, input, 1*time.Second, &wg) + go Ack(armadacontext.Background(), consumers, input, 1*time.Second, &wg) input <- []*ConsumerMessageId{ {NewMessageId(1), 0, 0}, {NewMessageId(2), 0, 0}, } diff --git a/internal/common/pulsarutils/eventsequence.go b/internal/common/pulsarutils/eventsequence.go index 49325bd0b2b..3750a1b11e8 100644 --- a/internal/common/pulsarutils/eventsequence.go +++ b/internal/common/pulsarutils/eventsequence.go @@ -1,24 +1,23 @@ package pulsarutils import ( - "context" "sync/atomic" - "github.com/armadaproject/armada/internal/common/schedulers" - "github.com/apache/pulsar-client-go/pulsar" "github.com/gogo/protobuf/proto" "github.com/hashicorp/go-multierror" "github.com/pkg/errors" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/eventutil" "github.com/armadaproject/armada/internal/common/requestid" + "github.com/armadaproject/armada/internal/common/schedulers" "github.com/armadaproject/armada/pkg/armadaevents" ) // CompactAndPublishSequences reduces the number of sequences to the smallest possible, // while respecting per-job set ordering and max Pulsar message size, and then publishes to Pulsar. -func CompactAndPublishSequences(ctx context.Context, sequences []*armadaevents.EventSequence, producer pulsar.Producer, maxMessageSizeInBytes uint, scheduler schedulers.Scheduler) error { +func CompactAndPublishSequences(ctx *armadacontext.Context, sequences []*armadaevents.EventSequence, producer pulsar.Producer, maxMessageSizeInBytes uint, scheduler schedulers.Scheduler) error { // Reduce the number of sequences to send to the minimum possible, // and then break up any sequences larger than maxMessageSizeInBytes. 
sequences = eventutil.CompactEventSequences(sequences) @@ -38,7 +37,7 @@ func CompactAndPublishSequences(ctx context.Context, sequences []*armadaevents.E // and // eventutil.LimitSequencesByteSize(sequences, int(srv.MaxAllowedMessageSize)) // before passing to this function. -func PublishSequences(ctx context.Context, producer pulsar.Producer, sequences []*armadaevents.EventSequence, scheduler schedulers.Scheduler) error { +func PublishSequences(ctx *armadacontext.Context, producer pulsar.Producer, sequences []*armadaevents.EventSequence, scheduler schedulers.Scheduler) error { // Incoming gRPC requests are annotated with a unique id. // Pass this id through the log by adding it to the Pulsar message properties. requestId := requestid.FromContextOrMissing(ctx) diff --git a/internal/common/pulsarutils/eventsequence_test.go b/internal/common/pulsarutils/eventsequence_test.go index 0613a9f3462..0832195beac 100644 --- a/internal/common/pulsarutils/eventsequence_test.go +++ b/internal/common/pulsarutils/eventsequence_test.go @@ -9,19 +9,20 @@ import ( "github.com/pkg/errors" "github.com/stretchr/testify/assert" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/schedulers" "github.com/armadaproject/armada/pkg/armadaevents" ) func TestPublishSequences_SendAsyncErr(t *testing.T) { producer := &mockProducer{} - err := PublishSequences(context.Background(), producer, []*armadaevents.EventSequence{{}}, schedulers.Pulsar) + err := PublishSequences(armadacontext.Background(), producer, []*armadaevents.EventSequence{{}}, schedulers.Pulsar) assert.NoError(t, err) producer = &mockProducer{ sendAsyncErr: errors.New("sendAsyncErr"), } - err = PublishSequences(context.Background(), producer, []*armadaevents.EventSequence{{}}, schedulers.Pulsar) + err = PublishSequences(armadacontext.Background(), producer, []*armadaevents.EventSequence{{}}, schedulers.Pulsar) assert.ErrorIs(t, err, producer.sendAsyncErr) } @@ -29,7 +30,7 @@ func TestPublishSequences_RespectTimeout(t *testing.T) { producer := &mockProducer{ sendAsyncDuration: 1 * time.Second, } - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), time.Millisecond) defer cancel() err := PublishSequences(ctx, producer, []*armadaevents.EventSequence{{}}, schedulers.Pulsar) assert.ErrorIs(t, err, context.DeadlineExceeded) diff --git a/internal/common/startup.go b/internal/common/startup.go index e14fa5a21a7..eebc23cdd3d 100644 --- a/internal/common/startup.go +++ b/internal/common/startup.go @@ -1,7 +1,6 @@ package common import ( - "context" "fmt" "net/http" "os" @@ -18,6 +17,7 @@ import ( "github.com/spf13/viper" "github.com/weaveworks/promrus" + "github.com/armadaproject/armada/internal/common/armadacontext" commonconfig "github.com/armadaproject/armada/internal/common/config" "github.com/armadaproject/armada/internal/common/logging" ) @@ -159,7 +159,7 @@ func ServeHttp(port uint16, mux http.Handler) (shutdown func()) { // TODO There's no need for this function to panic, since the main goroutine will exit. // Instead, just log an error. 
return func() { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() log.Printf("Stopping http server listening on %d", port) e := srv.Shutdown(ctx) diff --git a/internal/common/util/context.go b/internal/common/util/context.go index c96b4f0adee..1f6fa6519f4 100644 --- a/internal/common/util/context.go +++ b/internal/common/util/context.go @@ -1,11 +1,12 @@ package util import ( - "context" "time" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) -func CloseToDeadline(ctx context.Context, tolerance time.Duration) bool { +func CloseToDeadline(ctx *armadacontext.Context, tolerance time.Duration) bool { deadline, exists := ctx.Deadline() return exists && deadline.Before(time.Now().Add(tolerance)) } diff --git a/internal/common/util/retry.go b/internal/common/util/retry.go index 9f178c037d8..c688614e63e 100644 --- a/internal/common/util/retry.go +++ b/internal/common/util/retry.go @@ -1,8 +1,10 @@ package util -import "golang.org/x/net/context" +import ( + "github.com/armadaproject/armada/internal/common/armadacontext" +) -func RetryUntilSuccess(ctx context.Context, performAction func() error, onError func(error)) { +func RetryUntilSuccess(ctx *armadacontext.Context, performAction func() error, onError func(error)) { for { select { case <-ctx.Done(): diff --git a/internal/common/util/retry_test.go b/internal/common/util/retry_test.go index 43180ac6f39..2ad6ea4b300 100644 --- a/internal/common/util/retry_test.go +++ b/internal/common/util/retry_test.go @@ -1,16 +1,17 @@ package util import ( - "context" "fmt" "testing" "time" "github.com/stretchr/testify/assert" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) func TestRetryDoesntSpin(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 1*time.Second) defer cancel() RetryUntilSuccess( @@ -30,7 +31,7 @@ func TestRetryDoesntSpin(t *testing.T) { } func TestRetryCancel(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 1*time.Second) defer cancel() RetryUntilSuccess( @@ -61,7 +62,7 @@ func TestSucceedsAfterFailures(t *testing.T) { errorCount := 0 - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 1*time.Second) defer cancel() RetryUntilSuccess( diff --git a/internal/eventingester/convert/conversions.go b/internal/eventingester/convert/conversions.go index cab978e5812..fbb66a0c481 100644 --- a/internal/eventingester/convert/conversions.go +++ b/internal/eventingester/convert/conversions.go @@ -1,12 +1,11 @@ package convert import ( - "context" - "github.com/gogo/protobuf/proto" "github.com/pkg/errors" log "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/eventutil" "github.com/armadaproject/armada/internal/common/ingest" @@ -30,7 +29,7 @@ func NewEventConverter(compressor compress.Compressor, maxMessageBatchSize uint, } } -func (ec *EventConverter) Convert(ctx context.Context, sequencesWithIds *ingest.EventSequencesWithIds) *model.BatchUpdate { +func (ec *EventConverter) Convert(ctx *armadacontext.Context, sequencesWithIds 
*ingest.EventSequencesWithIds) *model.BatchUpdate { // Remove all groups as they are potentially quite large for _, es := range sequencesWithIds.EventSequences { es.Groups = nil diff --git a/internal/eventingester/convert/conversions_test.go b/internal/eventingester/convert/conversions_test.go index c716f84815a..24ff9013733 100644 --- a/internal/eventingester/convert/conversions_test.go +++ b/internal/eventingester/convert/conversions_test.go @@ -1,7 +1,6 @@ package convert import ( - "context" "math/rand" "testing" "time" @@ -11,6 +10,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/ingest" "github.com/armadaproject/armada/internal/common/pulsarutils" @@ -55,7 +55,7 @@ var cancelled = &armadaevents.EventSequence_Event{ func TestSingle(t *testing.T) { msg := NewMsg(jobRunSucceeded) converter := simpleEventConverter() - batchUpdate := converter.Convert(context.Background(), msg) + batchUpdate := converter.Convert(armadacontext.Background(), msg) expectedSequence := armadaevents.EventSequence{ Events: []*armadaevents.EventSequence_Event{jobRunSucceeded}, } @@ -72,7 +72,7 @@ func TestSingle(t *testing.T) { func TestMultiple(t *testing.T) { msg := NewMsg(cancelled, jobRunSucceeded) converter := simpleEventConverter() - batchUpdate := converter.Convert(context.Background(), msg) + batchUpdate := converter.Convert(armadacontext.Background(), msg) expectedSequence := armadaevents.EventSequence{ Events: []*armadaevents.EventSequence_Event{cancelled, jobRunSucceeded}, } @@ -113,7 +113,7 @@ func TestCancelled(t *testing.T) { }, }) converter := simpleEventConverter() - batchUpdate := converter.Convert(context.Background(), msg) + batchUpdate := converter.Convert(armadacontext.Background(), msg) assert.Equal(t, 1, len(batchUpdate.Events)) event := batchUpdate.Events[0] es, err := extractEventSeq(event.Event) diff --git a/internal/eventingester/store/eventstore.go b/internal/eventingester/store/eventstore.go index 2f9dc7555a2..981e8460c16 100644 --- a/internal/eventingester/store/eventstore.go +++ b/internal/eventingester/store/eventstore.go @@ -1,7 +1,6 @@ package store import ( - "context" "regexp" "time" @@ -9,6 +8,7 @@ import ( "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/ingest" "github.com/armadaproject/armada/internal/eventingester/configuration" "github.com/armadaproject/armada/internal/eventingester/model" @@ -39,7 +39,7 @@ func NewRedisEventStore(db redis.UniversalClient, eventRetention configuration.E } } -func (repo *RedisEventStore) Store(ctx context.Context, update *model.BatchUpdate) error { +func (repo *RedisEventStore) Store(ctx *armadacontext.Context, update *model.BatchUpdate) error { if len(update.Events) == 0 { return nil } diff --git a/internal/eventingester/store/eventstore_test.go b/internal/eventingester/store/eventstore_test.go index 3327a4fff95..1584b56ba15 100644 --- a/internal/eventingester/store/eventstore_test.go +++ b/internal/eventingester/store/eventstore_test.go @@ -1,13 +1,13 @@ package store import ( - "context" "testing" "time" "github.com/go-redis/redis" "github.com/stretchr/testify/assert" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/eventingester/configuration" 
"github.com/armadaproject/armada/internal/eventingester/model" ) @@ -29,7 +29,7 @@ func TestReportEvents(t *testing.T) { }, } - err := r.Store(context.Background(), update) + err := r.Store(armadacontext.Background(), update) assert.NoError(t, err) read1, err := ReadEvent(r.db, "testQueue", "testJobset") diff --git a/internal/executor/application.go b/internal/executor/application.go index 6a15c0f9414..3cb4db15af0 100644 --- a/internal/executor/application.go +++ b/internal/executor/application.go @@ -1,7 +1,6 @@ package executor import ( - "context" "fmt" "net/http" "os" @@ -14,10 +13,10 @@ import ( grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" - "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/grpc/keepalive" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/cluster" "github.com/armadaproject/armada/internal/common/etcdhealth" "github.com/armadaproject/armada/internal/common/healthmonitor" @@ -41,7 +40,7 @@ import ( "github.com/armadaproject/armada/pkg/executorapi" ) -func StartUp(ctx context.Context, log *logrus.Entry, config configuration.ExecutorConfiguration) (func(), *sync.WaitGroup) { +func StartUp(ctx *armadacontext.Context, log *logrus.Entry, config configuration.ExecutorConfiguration) (func(), *sync.WaitGroup) { err := validateConfig(config) if err != nil { log.Errorf("Invalid config: %s", err) @@ -59,7 +58,7 @@ func StartUp(ctx context.Context, log *logrus.Entry, config configuration.Execut } // Create an errgroup to run services in. - g, ctx := errgroup.WithContext(ctx) + g, ctx := armadacontext.ErrGroup(ctx) // Setup etcd health monitoring. etcdClusterHealthMonitoringByName := make(map[string]healthmonitor.HealthMonitor, len(config.Kubernetes.Etcd.EtcdClustersHealthMonitoring)) diff --git a/internal/executor/context/cluster_context.go b/internal/executor/context/cluster_context.go index 79619ea06fd..555303fe9a3 100644 --- a/internal/executor/context/cluster_context.go +++ b/internal/executor/context/cluster_context.go @@ -1,7 +1,6 @@ package context import ( - "context" "encoding/json" "fmt" "time" @@ -26,6 +25,7 @@ import ( "k8s.io/kubelet/pkg/apis/stats/v1alpha1" "k8s.io/utils/pointer" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/cluster" util2 "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/executor/configuration" @@ -50,7 +50,7 @@ type ClusterContext interface { GetActiveBatchPods() ([]*v1.Pod, error) GetNodes() ([]*v1.Node, error) GetNode(nodeName string) (*v1.Node, error) - GetNodeStatsSummary(context.Context, *v1.Node) (*v1alpha1.Summary, error) + GetNodeStatsSummary(*armadacontext.Context, *v1.Node) (*v1alpha1.Summary, error) GetPodEvents(pod *v1.Pod) ([]*v1.Event, error) GetServices(pod *v1.Pod) ([]*v1.Service, error) GetIngresses(pod *v1.Pod) ([]*networking.Ingress, error) @@ -223,7 +223,7 @@ func (c *KubernetesClusterContext) GetNode(nodeName string) (*v1.Node, error) { return c.nodeInformer.Lister().Get(nodeName) } -func (c *KubernetesClusterContext) GetNodeStatsSummary(ctx context.Context, node *v1.Node) (*v1alpha1.Summary, error) { +func (c *KubernetesClusterContext) GetNodeStatsSummary(ctx *armadacontext.Context, node *v1.Node) (*v1alpha1.Summary, error) { request := c.kubernetesClient. CoreV1(). RESTClient(). 
@@ -253,7 +253,7 @@ func (c *KubernetesClusterContext) SubmitPod(pod *v1.Pod, owner string, ownerGro return nil, err } - returnedPod, err := ownerClient.CoreV1().Pods(pod.Namespace).Create(context.Background(), pod, metav1.CreateOptions{}) + returnedPod, err := ownerClient.CoreV1().Pods(pod.Namespace).Create(armadacontext.Background(), pod, metav1.CreateOptions{}) if err != nil { c.submittedPods.Delete(util.ExtractPodKey(pod)) } @@ -261,11 +261,11 @@ func (c *KubernetesClusterContext) SubmitPod(pod *v1.Pod, owner string, ownerGro } func (c *KubernetesClusterContext) SubmitService(service *v1.Service) (*v1.Service, error) { - return c.kubernetesClient.CoreV1().Services(service.Namespace).Create(context.Background(), service, metav1.CreateOptions{}) + return c.kubernetesClient.CoreV1().Services(service.Namespace).Create(armadacontext.Background(), service, metav1.CreateOptions{}) } func (c *KubernetesClusterContext) SubmitIngress(ingress *networking.Ingress) (*networking.Ingress, error) { - return c.kubernetesClient.NetworkingV1().Ingresses(ingress.Namespace).Create(context.Background(), ingress, metav1.CreateOptions{}) + return c.kubernetesClient.NetworkingV1().Ingresses(ingress.Namespace).Create(armadacontext.Background(), ingress, metav1.CreateOptions{}) } func (c *KubernetesClusterContext) AddAnnotation(pod *v1.Pod, annotations map[string]string) error { @@ -280,7 +280,7 @@ func (c *KubernetesClusterContext) AddAnnotation(pod *v1.Pod, annotations map[st } _, err = c.kubernetesClient.CoreV1(). Pods(pod.Namespace). - Patch(context.Background(), pod.Name, types.StrategicMergePatchType, patchBytes, metav1.PatchOptions{}) + Patch(armadacontext.Background(), pod.Name, types.StrategicMergePatchType, patchBytes, metav1.PatchOptions{}) if err != nil { return err } @@ -299,7 +299,7 @@ func (c *KubernetesClusterContext) AddClusterEventAnnotation(event *v1.Event, an } _, err = c.kubernetesClient.CoreV1(). Events(event.Namespace). 
- Patch(context.Background(), event.Name, types.StrategicMergePatchType, patchBytes, metav1.PatchOptions{}) + Patch(armadacontext.Background(), event.Name, types.StrategicMergePatchType, patchBytes, metav1.PatchOptions{}) if err != nil { return err } @@ -318,7 +318,7 @@ func (c *KubernetesClusterContext) DeletePodWithCondition(pod *v1.Pod, condition return err } // Get latest pod state - bypassing cache - timeout, cancel := context.WithTimeout(context.Background(), time.Second*10) + timeout, cancel := armadacontext.WithTimeout(armadacontext.Background(), time.Second*10) defer cancel() currentPod, err = c.kubernetesClient.CoreV1().Pods(currentPod.Namespace).Get(timeout, currentPod.Name, metav1.GetOptions{}) if err != nil { @@ -368,7 +368,7 @@ func (c *KubernetesClusterContext) DeletePods(pods []*v1.Pod) { func (c *KubernetesClusterContext) DeleteService(service *v1.Service) error { deleteOptions := createDeleteOptions() - err := c.kubernetesClient.CoreV1().Services(service.Namespace).Delete(context.Background(), service.Name, deleteOptions) + err := c.kubernetesClient.CoreV1().Services(service.Namespace).Delete(armadacontext.Background(), service.Name, deleteOptions) if err != nil && k8s_errors.IsNotFound(err) { return nil } @@ -377,7 +377,7 @@ func (c *KubernetesClusterContext) DeleteService(service *v1.Service) error { func (c *KubernetesClusterContext) DeleteIngress(ingress *networking.Ingress) error { deleteOptions := createDeleteOptions() - err := c.kubernetesClient.NetworkingV1().Ingresses(ingress.Namespace).Delete(context.Background(), ingress.Name, deleteOptions) + err := c.kubernetesClient.NetworkingV1().Ingresses(ingress.Namespace).Delete(armadacontext.Background(), ingress.Name, deleteOptions) if err != nil && k8s_errors.IsNotFound(err) { return nil } @@ -386,7 +386,7 @@ func (c *KubernetesClusterContext) DeleteIngress(ingress *networking.Ingress) er func (c *KubernetesClusterContext) ProcessPodsToDelete() { pods := c.podsToDelete.GetAll() - util.ProcessItemsWithThreadPool(context.Background(), c.deleteThreadCount, pods, func(podToDelete *v1.Pod) { + util.ProcessItemsWithThreadPool(armadacontext.Background(), c.deleteThreadCount, pods, func(podToDelete *v1.Pod) { if podToDelete == nil { return } @@ -438,7 +438,7 @@ func (c *KubernetesClusterContext) doDelete(pod *v1.Pod, force bool) { } func (c *KubernetesClusterContext) deletePod(pod *v1.Pod, deleteOptions metav1.DeleteOptions) error { - return c.kubernetesClient.CoreV1().Pods(pod.Namespace).Delete(context.Background(), pod.Name, deleteOptions) + return c.kubernetesClient.CoreV1().Pods(pod.Namespace).Delete(armadacontext.Background(), pod.Name, deleteOptions) } func (c *KubernetesClusterContext) markForDeletion(pod *v1.Pod) (*v1.Pod, error) { diff --git a/internal/executor/context/cluster_context_test.go b/internal/executor/context/cluster_context_test.go index d1836e82168..b382cd0e690 100644 --- a/internal/executor/context/cluster_context_test.go +++ b/internal/executor/context/cluster_context_test.go @@ -1,7 +1,6 @@ package context import ( - ctx "context" "encoding/json" "errors" "testing" @@ -23,6 +22,7 @@ import ( clientTesting "k8s.io/client-go/testing" "k8s.io/utils/pointer" + "github.com/armadaproject/armada/internal/common/armadacontext" util2 "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/executor/configuration" "github.com/armadaproject/armada/internal/executor/domain" @@ -699,7 +699,7 @@ func TestKubernetesClusterContext_GetNodes(t *testing.T) { }, } - _, err := 
client.CoreV1().Nodes().Create(ctx.Background(), node, metav1.CreateOptions{}) + _, err := client.CoreV1().Nodes().Create(armadacontext.Background(), node, metav1.CreateOptions{}) assert.Nil(t, err) nodeFound := waitForCondition(func() bool { diff --git a/internal/executor/context/fake/sync_cluster_context.go b/internal/executor/context/fake/sync_cluster_context.go index 7a8d26797d0..d4a178920d0 100644 --- a/internal/executor/context/fake/sync_cluster_context.go +++ b/internal/executor/context/fake/sync_cluster_context.go @@ -1,7 +1,6 @@ package fake import ( - "context" "errors" "fmt" @@ -11,6 +10,7 @@ import ( "k8s.io/client-go/tools/cache" "k8s.io/kubelet/pkg/apis/stats/v1alpha1" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/executor/domain" ) @@ -132,7 +132,7 @@ func (c *SyncFakeClusterContext) GetClusterPool() string { return "pool" } -func (c *SyncFakeClusterContext) GetNodeStatsSummary(ctx context.Context, node *v1.Node) (*v1alpha1.Summary, error) { +func (c *SyncFakeClusterContext) GetNodeStatsSummary(ctx *armadacontext.Context, node *v1.Node) (*v1alpha1.Summary, error) { return &v1alpha1.Summary{}, nil } diff --git a/internal/executor/fake/context/context.go b/internal/executor/fake/context/context.go index 0cee687458b..906c23fe85f 100644 --- a/internal/executor/fake/context/context.go +++ b/internal/executor/fake/context/context.go @@ -1,7 +1,6 @@ package context import ( - "context" "fmt" "math/rand" "regexp" @@ -23,6 +22,7 @@ import ( "k8s.io/client-go/tools/cache" "k8s.io/kubelet/pkg/apis/stats/v1alpha1" + "github.com/armadaproject/armada/internal/common/armadacontext" armadaresource "github.com/armadaproject/armada/internal/common/resource" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/executor/configuration" @@ -314,7 +314,7 @@ func (c *FakeClusterContext) GetClusterPool() string { return c.pool } -func (c *FakeClusterContext) GetNodeStatsSummary(ctx context.Context, node *v1.Node) (*v1alpha1.Summary, error) { +func (c *FakeClusterContext) GetNodeStatsSummary(ctx *armadacontext.Context, node *v1.Node) (*v1alpha1.Summary, error) { return &v1alpha1.Summary{}, nil } diff --git a/internal/executor/job/job_context.go b/internal/executor/job/job_context.go index 3cc8b36f2b3..bcc5526ce2e 100644 --- a/internal/executor/job/job_context.go +++ b/internal/executor/job/job_context.go @@ -1,7 +1,6 @@ package job import ( - "context" "fmt" "sync" "time" @@ -10,6 +9,7 @@ import ( v1 "k8s.io/api/core/v1" "k8s.io/client-go/tools/cache" + "github.com/armadaproject/armada/internal/common/armadacontext" executorContext "github.com/armadaproject/armada/internal/executor/context" "github.com/armadaproject/armada/internal/executor/podchecks" "github.com/armadaproject/armada/internal/executor/util" @@ -149,7 +149,7 @@ func (c *ClusterJobContext) AddAnnotation(jobs []*RunningJob, annotations map[st } } - util.ProcessItemsWithThreadPool(context.Background(), c.updateThreadCount, podsToAnnotate, + util.ProcessItemsWithThreadPool(armadacontext.Background(), c.updateThreadCount, podsToAnnotate, func(pod *v1.Pod) { err := c.clusterContext.AddAnnotation(pod, annotations) if err != nil { diff --git a/internal/executor/job/processors/preempt_runs.go b/internal/executor/job/processors/preempt_runs.go index 9e48adb71a6..c296f7b75f7 100644 --- a/internal/executor/job/processors/preempt_runs.go +++ 
b/internal/executor/job/processors/preempt_runs.go @@ -1,13 +1,13 @@ package processors import ( - "context" "fmt" "time" log "github.com/sirupsen/logrus" v1 "k8s.io/api/core/v1" + "github.com/armadaproject/armada/internal/common/armadacontext" executorContext "github.com/armadaproject/armada/internal/executor/context" "github.com/armadaproject/armada/internal/executor/domain" "github.com/armadaproject/armada/internal/executor/job" @@ -46,7 +46,7 @@ func (j *RunPreemptedProcessor) Run() { }) runPodInfos := createRunPodInfos(runsToCancel, managedPods) - util.ProcessItemsWithThreadPool(context.Background(), 20, runPodInfos, + util.ProcessItemsWithThreadPool(armadacontext.Background(), 20, runPodInfos, func(runInfo *runPodInfo) { pod := runInfo.Pod if pod == nil { diff --git a/internal/executor/job/processors/remove_runs.go b/internal/executor/job/processors/remove_runs.go index 83038d8c1e1..37942110605 100644 --- a/internal/executor/job/processors/remove_runs.go +++ b/internal/executor/job/processors/remove_runs.go @@ -1,12 +1,12 @@ package processors import ( - "context" "time" log "github.com/sirupsen/logrus" v1 "k8s.io/api/core/v1" + "github.com/armadaproject/armada/internal/common/armadacontext" executorContext "github.com/armadaproject/armada/internal/executor/context" "github.com/armadaproject/armada/internal/executor/domain" "github.com/armadaproject/armada/internal/executor/job" @@ -37,7 +37,7 @@ func (j *RemoveRunProcessor) Run() { }) runPodInfos := createRunPodInfos(runsToCancel, managedPods) - util.ProcessItemsWithThreadPool(context.Background(), 20, runPodInfos, + util.ProcessItemsWithThreadPool(armadacontext.Background(), 20, runPodInfos, func(runInfo *runPodInfo) { pod := runInfo.Pod if pod == nil { diff --git a/internal/executor/reporter/event_sender.go b/internal/executor/reporter/event_sender.go index 9dd42a03f9d..d9afe0fa48b 100644 --- a/internal/executor/reporter/event_sender.go +++ b/internal/executor/reporter/event_sender.go @@ -1,13 +1,12 @@ package reporter import ( - "context" - "github.com/gogo/protobuf/proto" "github.com/pkg/errors" log "github.com/sirupsen/logrus" "github.com/armadaproject/armada/internal/common" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/eventutil" "github.com/armadaproject/armada/pkg/api" "github.com/armadaproject/armada/pkg/armadaevents" @@ -63,7 +62,7 @@ func (eventSender *ExecutorApiEventSender) SendEvents(events []EventMessage) err } for _, eventList := range eventLists { - _, err = eventSender.eventClient.ReportEvents(context.Background(), eventList) + _, err = eventSender.eventClient.ReportEvents(armadacontext.Background(), eventList) if err != nil { return err } diff --git a/internal/executor/reporter/event_sender_test.go b/internal/executor/reporter/event_sender_test.go index 1c91d1cd6f0..08e60daa521 100644 --- a/internal/executor/reporter/event_sender_test.go +++ b/internal/executor/reporter/event_sender_test.go @@ -205,13 +205,13 @@ func newFakeExecutorApiClient() *fakeExecutorApiClient { } } -func (fakeClient *fakeExecutorApiClient) LeaseJobRuns(ctx context.Context, opts ...grpc.CallOption) (executorapi.ExecutorApi_LeaseJobRunsClient, error) { +func (fakeClient *fakeExecutorApiClient) LeaseJobRuns(_ context.Context, opts ...grpc.CallOption) (executorapi.ExecutorApi_LeaseJobRunsClient, error) { // Not implemented return nil, nil } // Reports job run events to the scheduler -func (fakeClient *fakeExecutorApiClient) ReportEvents(ctx context.Context, in 
*executorapi.EventList, opts ...grpc.CallOption) (*types.Empty, error) { +func (fakeClient *fakeExecutorApiClient) ReportEvents(_ context.Context, in *executorapi.EventList, opts ...grpc.CallOption) (*types.Empty, error) { fakeClient.reportedEvents = append(fakeClient.reportedEvents, in) return nil, nil } diff --git a/internal/executor/service/job_lease.go b/internal/executor/service/job_lease.go index d8165c32a1b..1b18fc0c9d2 100644 --- a/internal/executor/service/job_lease.go +++ b/internal/executor/service/job_lease.go @@ -10,7 +10,6 @@ import ( grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" "github.com/pkg/errors" log "github.com/sirupsen/logrus" - "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/grpc/encoding/gzip" v1 "k8s.io/api/core/v1" @@ -18,6 +17,7 @@ import ( "github.com/armadaproject/armada/internal/armada/configuration" "github.com/armadaproject/armada/internal/common" + "github.com/armadaproject/armada/internal/common/armadacontext" armadamaps "github.com/armadaproject/armada/internal/common/maps" armadaresource "github.com/armadaproject/armada/internal/common/resource" commonUtil "github.com/armadaproject/armada/internal/common/util" @@ -111,10 +111,10 @@ func (jobLeaseService *JobLeaseService) requestJobLeases(leaseRequest *api.Strea // Setup a bidirectional gRPC stream. // The server sends jobs over this stream. // The executor sends back acks to indicate which jobs were successfully received. - ctx := context.Background() + ctx := armadacontext.Background() var cancel context.CancelFunc if jobLeaseService.jobLeaseRequestTimeout != 0 { - ctx, cancel = context.WithTimeout(ctx, jobLeaseService.jobLeaseRequestTimeout) + ctx, cancel = armadacontext.WithTimeout(ctx, jobLeaseService.jobLeaseRequestTimeout) defer cancel() } stream, err := jobLeaseService.queueClient.StreamingLeaseJobs(ctx, grpc_retry.Disable(), grpc.UseCompressor(gzip.Name)) @@ -137,7 +137,7 @@ func (jobLeaseService *JobLeaseService) requestJobLeases(leaseRequest *api.Strea var numJobs uint32 jobs := make([]*api.Job, 0) ch := make(chan *api.StreamingJobLease, 10) - g, ctx := errgroup.WithContext(ctx) + g, ctx := armadacontext.ErrGroup(ctx) g.Go(func() error { // Close channel to ensure sending goroutine exits. 
defer close(ch) diff --git a/internal/executor/service/job_manager.go b/internal/executor/service/job_manager.go index 4b8b1cfe016..496440d0538 100644 --- a/internal/executor/service/job_manager.go +++ b/internal/executor/service/job_manager.go @@ -1,13 +1,13 @@ package service import ( - "context" "fmt" "time" log "github.com/sirupsen/logrus" v1 "k8s.io/api/core/v1" + "github.com/armadaproject/armada/internal/common/armadacontext" context2 "github.com/armadaproject/armada/internal/executor/context" "github.com/armadaproject/armada/internal/executor/domain" "github.com/armadaproject/armada/internal/executor/job" @@ -75,7 +75,7 @@ func (m *JobManager) ManageJobLeases() { } } - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*2) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), time.Minute*2) defer cancel() m.handlePodIssues(ctx, jobs) } @@ -108,7 +108,7 @@ func (m *JobManager) reportTerminated(pods []*v1.Pod) { } } -func (m *JobManager) handlePodIssues(ctx context.Context, allRunningJobs []*job.RunningJob) { +func (m *JobManager) handlePodIssues(ctx *armadacontext.Context, allRunningJobs []*job.RunningJob) { util.ProcessItemsWithThreadPool(ctx, 20, allRunningJobs, m.handlePodIssue) } diff --git a/internal/executor/service/job_requester.go b/internal/executor/service/job_requester.go index 217f279639e..53cf83c49a6 100644 --- a/internal/executor/service/job_requester.go +++ b/internal/executor/service/job_requester.go @@ -1,12 +1,12 @@ package service import ( - "context" "time" log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/slices" util2 "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/executor/configuration" @@ -56,7 +56,7 @@ func (r *JobRequester) RequestJobsRuns() { log.Errorf("Failed to create lease request because %s", err) return } - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 30*time.Second) defer cancel() leaseResponse, err := r.leaseRequester.LeaseJobRuns(ctx, leaseRequest) if err != nil { diff --git a/internal/executor/service/job_requester_test.go b/internal/executor/service/job_requester_test.go index 532e7e4fb0e..f7e3fcbc5b7 100644 --- a/internal/executor/service/job_requester_test.go +++ b/internal/executor/service/job_requester_test.go @@ -1,7 +1,6 @@ package service import ( - "context" "fmt" "testing" @@ -11,6 +10,7 @@ import ( v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" + "github.com/armadaproject/armada/internal/common/armadacontext" armadaresource "github.com/armadaproject/armada/internal/common/resource" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/executor/configuration" @@ -275,7 +275,7 @@ type StubLeaseRequester struct { LeaseJobRunLeaseResponse *LeaseResponse } -func (s *StubLeaseRequester) LeaseJobRuns(ctx context.Context, request *LeaseRequest) (*LeaseResponse, error) { +func (s *StubLeaseRequester) LeaseJobRuns(_ *armadacontext.Context, request *LeaseRequest) (*LeaseResponse, error) { s.ReceivedLeaseRequests = append(s.ReceivedLeaseRequests, request) return s.LeaseJobRunLeaseResponse, s.LeaseJobRunError } diff --git a/internal/executor/service/lease_requester.go b/internal/executor/service/lease_requester.go index 36a29f6e4f2..dc4976d84b5 100644 --- 
a/internal/executor/service/lease_requester.go +++ b/internal/executor/service/lease_requester.go @@ -10,6 +10,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/encoding/gzip" + "github.com/armadaproject/armada/internal/common/armadacontext" armadaresource "github.com/armadaproject/armada/internal/common/resource" clusterContext "github.com/armadaproject/armada/internal/executor/context" "github.com/armadaproject/armada/pkg/api" @@ -31,7 +32,7 @@ type LeaseResponse struct { } type LeaseRequester interface { - LeaseJobRuns(ctx context.Context, request *LeaseRequest) (*LeaseResponse, error) + LeaseJobRuns(ctx *armadacontext.Context, request *LeaseRequest) (*LeaseResponse, error) } type JobLeaseRequester struct { @@ -52,7 +53,7 @@ func NewJobLeaseRequester( } } -func (requester *JobLeaseRequester) LeaseJobRuns(ctx context.Context, request *LeaseRequest) (*LeaseResponse, error) { +func (requester *JobLeaseRequester) LeaseJobRuns(ctx *armadacontext.Context, request *LeaseRequest) (*LeaseResponse, error) { stream, err := requester.executorApiClient.LeaseJobRuns(ctx, grpcretry.Disable(), grpc.UseCompressor(gzip.Name)) if err != nil { return nil, err diff --git a/internal/executor/service/lease_requester_test.go b/internal/executor/service/lease_requester_test.go index 3f09cf450a7..f6314876c9f 100644 --- a/internal/executor/service/lease_requester_test.go +++ b/internal/executor/service/lease_requester_test.go @@ -1,7 +1,6 @@ package service import ( - "context" "fmt" "io" "testing" @@ -12,6 +11,7 @@ import ( "github.com/stretchr/testify/assert" "k8s.io/apimachinery/pkg/api/resource" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/mocks" armadaresource "github.com/armadaproject/armada/internal/common/resource" "github.com/armadaproject/armada/internal/executor/context/fake" @@ -39,7 +39,7 @@ var ( ) func TestLeaseJobRuns(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 30*time.Second) defer cancel() tests := map[string]struct { leaseMessages []*executorapi.JobRunLease @@ -87,7 +87,7 @@ func TestLeaseJobRuns(t *testing.T) { } func TestLeaseJobRuns_Send(t *testing.T) { - shortCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + shortCtx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 30*time.Second) defer cancel() leaseRequest := &LeaseRequest{ @@ -126,7 +126,7 @@ func TestLeaseJobRuns_Send(t *testing.T) { func TestLeaseJobRuns_HandlesNoEndMarkerMessage(t *testing.T) { leaseMessages := []*executorapi.JobRunLease{lease1, lease2} - shortCtx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + shortCtx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 200*time.Millisecond) defer cancel() jobRequester, mockExecutorApiClient, mockStream := setup(t) @@ -146,7 +146,7 @@ func TestLeaseJobRuns_HandlesNoEndMarkerMessage(t *testing.T) { } func TestLeaseJobRuns_Error(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 30*time.Second) defer cancel() tests := map[string]struct { streamError bool diff --git a/internal/executor/service/pod_issue_handler.go b/internal/executor/service/pod_issue_handler.go index b98980df01b..57b323e7146 100644 --- a/internal/executor/service/pod_issue_handler.go +++ 
b/internal/executor/service/pod_issue_handler.go @@ -1,7 +1,6 @@ package service import ( - "context" "fmt" "sync" "time" @@ -11,6 +10,7 @@ import ( "k8s.io/apimachinery/pkg/util/clock" "k8s.io/client-go/tools/cache" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/executor/configuration" executorContext "github.com/armadaproject/armada/internal/executor/context" "github.com/armadaproject/armada/internal/executor/job" @@ -159,7 +159,7 @@ func (p *IssueHandler) HandlePodIssues() { }) p.detectPodIssues(managedPods) p.detectReconciliationIssues(managedPods) - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*2) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), time.Minute*2) defer cancel() p.handleKnownIssues(ctx, managedPods) } @@ -225,7 +225,7 @@ func (p *IssueHandler) detectPodIssues(allManagedPods []*v1.Pod) { } } -func (p *IssueHandler) handleKnownIssues(ctx context.Context, allManagedPods []*v1.Pod) { +func (p *IssueHandler) handleKnownIssues(ctx *armadacontext.Context, allManagedPods []*v1.Pod) { // Make issues from pods + issues issues := createIssues(allManagedPods, p.knownPodIssues) util.ProcessItemsWithThreadPool(ctx, 20, issues, p.handleRunIssue) diff --git a/internal/executor/util/process.go b/internal/executor/util/process.go index cc4da52d9a2..a38c316b5fa 100644 --- a/internal/executor/util/process.go +++ b/internal/executor/util/process.go @@ -1,13 +1,13 @@ package util import ( - "context" "sync" + "github.com/armadaproject/armada/internal/common/armadacontext" commonUtil "github.com/armadaproject/armada/internal/common/util" ) -func ProcessItemsWithThreadPool[K any](ctx context.Context, maxThreadCount int, itemsToProcess []K, processFunc func(K)) { +func ProcessItemsWithThreadPool[K any](ctx *armadacontext.Context, maxThreadCount int, itemsToProcess []K, processFunc func(K)) { wg := &sync.WaitGroup{} processChannel := make(chan K) @@ -24,7 +24,7 @@ func ProcessItemsWithThreadPool[K any](ctx context.Context, maxThreadCount int, wg.Wait() } -func poolWorker[K any](ctx context.Context, wg *sync.WaitGroup, podsToProcess chan K, processFunc func(K)) { +func poolWorker[K any](ctx *armadacontext.Context, wg *sync.WaitGroup, podsToProcess chan K, processFunc func(K)) { defer wg.Done() for pod := range podsToProcess { diff --git a/internal/executor/util/process_test.go b/internal/executor/util/process_test.go index cfdb237dea9..f6995106c70 100644 --- a/internal/executor/util/process_test.go +++ b/internal/executor/util/process_test.go @@ -1,12 +1,13 @@ package util import ( - "context" "sync" "testing" "time" "github.com/stretchr/testify/assert" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) func TestProcessItemsWithThreadPool(t *testing.T) { @@ -14,7 +15,7 @@ func TestProcessItemsWithThreadPool(t *testing.T) { output := []string{} outputMutex := &sync.Mutex{} - ProcessItemsWithThreadPool(context.Background(), 2, input, func(item string) { + ProcessItemsWithThreadPool(armadacontext.Background(), 2, input, func(item string) { outputMutex.Lock() defer outputMutex.Unlock() output = append(output, item) @@ -28,7 +29,7 @@ func TestProcessItemsWithThreadPool_HandlesContextCancellation(t *testing.T) { output := []string{} outputMutex := &sync.Mutex{} - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), time.Millisecond*100) defer cancel() ProcessItemsWithThreadPool(ctx, 2, 
input, func(item string) { diff --git a/internal/executor/utilisation/pod_utilisation_kubelet_metrics.go b/internal/executor/utilisation/pod_utilisation_kubelet_metrics.go index 258d3740942..d7b32d01c7e 100644 --- a/internal/executor/utilisation/pod_utilisation_kubelet_metrics.go +++ b/internal/executor/utilisation/pod_utilisation_kubelet_metrics.go @@ -1,7 +1,6 @@ package utilisation import ( - "context" "sync" "time" @@ -11,6 +10,7 @@ import ( log "github.com/sirupsen/logrus" v1 "k8s.io/api/core/v1" + "github.com/armadaproject/armada/internal/common/armadacontext" clusterContext "github.com/armadaproject/armada/internal/executor/context" "github.com/armadaproject/armada/internal/executor/domain" ) @@ -28,7 +28,7 @@ func (m *podUtilisationKubeletMetrics) fetch(nodes []*v1.Node, podNameToUtilisat wg.Add(1) go func(node *v1.Node) { defer wg.Done() - ctx, cancelFunc := context.WithTimeout(context.Background(), time.Second*15) + ctx, cancelFunc := armadacontext.WithTimeout(armadacontext.Background(), time.Second*15) defer cancelFunc() summary, err := clusterContext.GetNodeStatsSummary(ctx, node) if err != nil { diff --git a/internal/lookout/repository/job_pruner.go b/internal/lookout/repository/job_pruner.go index 6c9b92e60df..a77f1657007 100644 --- a/internal/lookout/repository/job_pruner.go +++ b/internal/lookout/repository/job_pruner.go @@ -1,12 +1,13 @@ package repository import ( - "context" "database/sql" "fmt" "time" log "github.com/sirupsen/logrus" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) const postgresFormat = "2006-01-02 15:04:05.000000" @@ -22,7 +23,7 @@ const postgresFormat = "2006-01-02 15:04:05.000000" // For performance reasons we don't use a transaction here and so an error may indicate that // Some jobs were deleted. 
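// --- Illustrative sketch (not part of the patch) ----------------------------
// The hunks around this point repeat a single pattern: context.WithTimeout
// over context.Background() becomes armadacontext.WithTimeout over
// armadacontext.Background(). Judging by the call sites in this diff,
// *armadacontext.Context still satisfies context.Context (it is handed
// directly to the Kubernetes, pgx and gRPC clients), so only the constructors
// change. The snippet below shows the pattern in isolation; the package name,
// the 15-second bound and the fetchSummary helper are assumptions made purely
// for illustration.
package example

import (
	"time"

	"github.com/armadaproject/armada/internal/common/armadacontext"
)

// fetchSummary stands in for any call that accepts a context, for example
// clusterContext.GetNodeStatsSummary in the utilisation hunk above.
func fetchSummary(ctx *armadacontext.Context) error {
	select {
	case <-ctx.Done(): // available because *armadacontext.Context implements context.Context
		return ctx.Err()
	case <-time.After(100 * time.Millisecond):
		return nil
	}
}

func fetchWithDeadline() error {
	ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 15*time.Second)
	defer cancel()
	return fetchSummary(ctx)
}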
func DeleteOldJobs(db *sql.DB, batchSizeLimit int, cutoff time.Time) error { - ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 120*time.Second) defer cancel() // This would be much better done as a proper statement with parameters, but postgres doesn't support @@ -30,7 +31,7 @@ func DeleteOldJobs(db *sql.DB, batchSizeLimit int, cutoff time.Time) error { queryText := fmt.Sprintf(` CREATE TEMP TABLE rows_to_delete AS (SELECT job_id FROM job WHERE submitted < '%v' OR submitted IS NULL); CREATE TEMP TABLE batch (job_id varchar(32)); - + DO $do$ DECLARE @@ -52,7 +53,7 @@ func DeleteOldJobs(db *sql.DB, batchSizeLimit int, cutoff time.Time) error { END LOOP; END; $do$; - + DROP TABLE rows_to_delete; DROP TABLE batch; `, cutoff.Format(postgresFormat), batchSizeLimit) diff --git a/internal/lookout/repository/job_sets.go b/internal/lookout/repository/job_sets.go index 70bd187f866..28b60a48179 100644 --- a/internal/lookout/repository/job_sets.go +++ b/internal/lookout/repository/job_sets.go @@ -1,7 +1,6 @@ package repository import ( - "context" "database/sql" "time" @@ -9,6 +8,7 @@ import ( "github.com/doug-martin/goqu/v9/exp" "github.com/gogo/protobuf/types" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database" "github.com/armadaproject/armada/pkg/api/lookout" ) @@ -38,7 +38,7 @@ type jobSetCountsRow struct { QueuedStatsQ3 sql.NullTime `db:"queued_q3"` } -func (r *SQLJobRepository) GetJobSetInfos(ctx context.Context, opts *lookout.GetJobSetsRequest) ([]*lookout.JobSetInfo, error) { +func (r *SQLJobRepository) GetJobSetInfos(ctx *armadacontext.Context, opts *lookout.GetJobSetsRequest) ([]*lookout.JobSetInfo, error) { rows, err := r.queryJobSetInfos(ctx, opts) if err != nil { return nil, err @@ -47,7 +47,7 @@ func (r *SQLJobRepository) GetJobSetInfos(ctx context.Context, opts *lookout.Get return r.rowsToJobSets(rows, opts.Queue), nil } -func (r *SQLJobRepository) queryJobSetInfos(ctx context.Context, opts *lookout.GetJobSetsRequest) ([]*jobSetCountsRow, error) { +func (r *SQLJobRepository) queryJobSetInfos(ctx *armadacontext.Context, opts *lookout.GetJobSetsRequest) ([]*jobSetCountsRow, error) { ds := r.createJobSetsDataset(opts) jobsInQueueRows := make([]*jobSetCountsRow, 0) diff --git a/internal/lookout/repository/jobs.go b/internal/lookout/repository/jobs.go index dc03c6d43c4..3d8cb0994c1 100644 --- a/internal/lookout/repository/jobs.go +++ b/internal/lookout/repository/jobs.go @@ -1,7 +1,6 @@ package repository import ( - "context" "encoding/json" "errors" "fmt" @@ -13,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" "k8s.io/apimachinery/pkg/util/duration" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/database" "github.com/armadaproject/armada/internal/common/util" @@ -20,7 +20,7 @@ import ( "github.com/armadaproject/armada/pkg/api/lookout" ) -func (r *SQLJobRepository) GetJobs(ctx context.Context, opts *lookout.GetJobsRequest) ([]*lookout.JobInfo, error) { +func (r *SQLJobRepository) GetJobs(ctx *armadacontext.Context, opts *lookout.GetJobsRequest) ([]*lookout.JobInfo, error) { if valid, jobState := validateJobStates(opts.JobStates); !valid { return nil, fmt.Errorf("unknown job state: %q", jobState) } @@ -57,7 +57,7 @@ func isJobState(val string) bool { return false } -func (r *SQLJobRepository) 
queryJobs(ctx context.Context, opts *lookout.GetJobsRequest) ([]*JobRow, error) { +func (r *SQLJobRepository) queryJobs(ctx *armadacontext.Context, opts *lookout.GetJobsRequest) ([]*JobRow, error) { ds := r.createJobsDataset(opts) jobsInQueueRows := make([]*JobRow, 0) diff --git a/internal/lookout/repository/queues.go b/internal/lookout/repository/queues.go index 32b40aeb3b7..0ad1909e849 100644 --- a/internal/lookout/repository/queues.go +++ b/internal/lookout/repository/queues.go @@ -1,7 +1,6 @@ package repository import ( - "context" "database/sql" "sort" "time" @@ -10,6 +9,7 @@ import ( "github.com/gogo/protobuf/types" "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/pkg/api/lookout" ) @@ -25,7 +25,7 @@ type rowsSql struct { LongestRunning string } -func (r *SQLJobRepository) GetQueueInfos(ctx context.Context) ([]*lookout.QueueInfo, error) { +func (r *SQLJobRepository) GetQueueInfos(ctx *armadacontext.Context) ([]*lookout.QueueInfo, error) { queries, err := r.getQueuesSql() if err != nil { return nil, err diff --git a/internal/lookout/repository/sql_repository.go b/internal/lookout/repository/sql_repository.go index 42d72473c5e..af59e92c6ed 100644 --- a/internal/lookout/repository/sql_repository.go +++ b/internal/lookout/repository/sql_repository.go @@ -1,12 +1,12 @@ package repository import ( - "context" "database/sql" "github.com/doug-martin/goqu/v9" _ "github.com/doug-martin/goqu/v9/dialect/postgres" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/pkg/api/lookout" ) @@ -33,9 +33,9 @@ const ( ) type JobRepository interface { - GetQueueInfos(ctx context.Context) ([]*lookout.QueueInfo, error) - GetJobSetInfos(ctx context.Context, opts *lookout.GetJobSetsRequest) ([]*lookout.JobSetInfo, error) - GetJobs(ctx context.Context, opts *lookout.GetJobsRequest) ([]*lookout.JobInfo, error) + GetQueueInfos(ctx *armadacontext.Context) ([]*lookout.QueueInfo, error) + GetJobSetInfos(ctx *armadacontext.Context, opts *lookout.GetJobSetsRequest) ([]*lookout.JobSetInfo, error) + GetJobs(ctx *armadacontext.Context, opts *lookout.GetJobsRequest) ([]*lookout.JobInfo, error) } type SQLJobRepository struct { diff --git a/internal/lookout/repository/utils_test.go b/internal/lookout/repository/utils_test.go index 1c073851d0f..54fb40bcc6a 100644 --- a/internal/lookout/repository/utils_test.go +++ b/internal/lookout/repository/utils_test.go @@ -1,7 +1,6 @@ package repository import ( - "context" "fmt" "testing" "time" @@ -12,6 +11,7 @@ import ( "github.com/stretchr/testify/assert" v1 "k8s.io/api/core/v1" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/pkg/api" "github.com/armadaproject/armada/pkg/api/lookout" @@ -29,7 +29,7 @@ var ( node = "node" someTimeUnix = int64(1612546858) someTime = time.Unix(someTimeUnix, 0) - ctx = context.Background() + ctx = armadacontext.Background() ) func AssertJobsAreEquivalent(t *testing.T, expected *api.Job, actual *api.Job) { diff --git a/internal/lookout/server/lookout.go b/internal/lookout/server/lookout.go index df95e7bc2de..cf48fe278aa 100644 --- a/internal/lookout/server/lookout.go +++ b/internal/lookout/server/lookout.go @@ -4,9 +4,11 @@ import ( "context" "github.com/gogo/protobuf/types" + "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + 
"github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/lookout/repository" "github.com/armadaproject/armada/pkg/api/lookout" ) @@ -20,7 +22,7 @@ func NewLookoutServer(jobRepository repository.JobRepository) *LookoutServer { } func (s *LookoutServer) Overview(ctx context.Context, _ *types.Empty) (*lookout.SystemOverview, error) { - queues, err := s.jobRepository.GetQueueInfos(ctx) + queues, err := s.jobRepository.GetQueueInfos(armadacontext.New(ctx, logrus.NewEntry(logrus.New()))) if err != nil { return nil, status.Errorf(codes.Internal, "failed to query queue stats: %s", err) } @@ -28,7 +30,7 @@ func (s *LookoutServer) Overview(ctx context.Context, _ *types.Empty) (*lookout. } func (s *LookoutServer) GetJobSets(ctx context.Context, opts *lookout.GetJobSetsRequest) (*lookout.GetJobSetsResponse, error) { - jobSets, err := s.jobRepository.GetJobSetInfos(ctx, opts) + jobSets, err := s.jobRepository.GetJobSetInfos(armadacontext.New(ctx, logrus.NewEntry(logrus.New())), opts) if err != nil { return nil, status.Errorf(codes.Internal, "failed to query queue stats: %s", err) } @@ -36,7 +38,7 @@ func (s *LookoutServer) GetJobSets(ctx context.Context, opts *lookout.GetJobSets } func (s *LookoutServer) GetJobs(ctx context.Context, opts *lookout.GetJobsRequest) (*lookout.GetJobsResponse, error) { - jobInfos, err := s.jobRepository.GetJobs(ctx, opts) + jobInfos, err := s.jobRepository.GetJobs(armadacontext.New(ctx, logrus.NewEntry(logrus.New())), opts) if err != nil { return nil, status.Errorf(codes.Internal, "failed to query jobs in queue: %s", err) } diff --git a/internal/lookout/testutil/db_testutil.go b/internal/lookout/testutil/db_testutil.go index 5ce57e8effa..eaba3992c15 100644 --- a/internal/lookout/testutil/db_testutil.go +++ b/internal/lookout/testutil/db_testutil.go @@ -1,7 +1,6 @@ package testutil import ( - "context" "database/sql" "fmt" @@ -9,6 +8,7 @@ import ( _ "github.com/jackc/pgx/v5/stdlib" "github.com/pkg/errors" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/lookout/repository/schema" ) @@ -61,7 +61,7 @@ func WithDatabase(action func(db *sql.DB) error) error { } func WithDatabasePgx(action func(db *pgxpool.Pool) error) error { - ctx := context.Background() + ctx := armadacontext.Background() // Connect and create a dedicated database for the test // For now use database/sql for this diff --git a/internal/lookoutingester/instructions/instructions.go b/internal/lookoutingester/instructions/instructions.go index f49ac049975..2e6b314fc66 100644 --- a/internal/lookoutingester/instructions/instructions.go +++ b/internal/lookoutingester/instructions/instructions.go @@ -1,23 +1,21 @@ package instructions import ( - "context" "sort" "strings" "time" - "github.com/armadaproject/armada/internal/common/ingest/metrics" - "github.com/gogo/protobuf/proto" "github.com/google/uuid" "github.com/pkg/errors" log "github.com/sirupsen/logrus" "k8s.io/utils/pointer" - "github.com/armadaproject/armada/internal/common/ingest" - + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/eventutil" + "github.com/armadaproject/armada/internal/common/ingest" + "github.com/armadaproject/armada/internal/common/ingest/metrics" "github.com/armadaproject/armada/internal/common/util" 
"github.com/armadaproject/armada/internal/lookout/repository" "github.com/armadaproject/armada/internal/lookoutingester/model" @@ -42,7 +40,7 @@ func NewInstructionConverter(metrics *metrics.Metrics, userAnnotationPrefix stri } } -func (c *InstructionConverter) Convert(ctx context.Context, sequencesWithIds *ingest.EventSequencesWithIds) *model.InstructionSet { +func (c *InstructionConverter) Convert(ctx *armadacontext.Context, sequencesWithIds *ingest.EventSequencesWithIds) *model.InstructionSet { updateInstructions := &model.InstructionSet{ MessageIds: sequencesWithIds.MessageIds, } diff --git a/internal/lookoutingester/instructions/instructions_test.go b/internal/lookoutingester/instructions/instructions_test.go index 5510a30a695..3f4f3043101 100644 --- a/internal/lookoutingester/instructions/instructions_test.go +++ b/internal/lookoutingester/instructions/instructions_test.go @@ -1,7 +1,6 @@ package instructions import ( - "context" "testing" "time" @@ -13,6 +12,7 @@ import ( "k8s.io/apimachinery/pkg/api/resource" "k8s.io/utils/pointer" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/eventutil" "github.com/armadaproject/armada/internal/common/ingest" @@ -339,7 +339,7 @@ var expectedJobRunContainer = model.CreateJobRunContainerInstruction{ func TestSubmit(t *testing.T) { svc := SimpleInstructionConverter() msg := NewMsg(submit) - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expected := &model.InstructionSet{ JobsToCreate: []*model.CreateJobInstruction{&expectedSubmit}, MessageIds: msg.MessageIds, @@ -351,7 +351,7 @@ func TestSubmit(t *testing.T) { func TestDuplicate(t *testing.T) { svc := SimpleInstructionConverter() msg := NewMsg(testfixtures.SubmitDuplicate) - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expected := &model.InstructionSet{ MessageIds: msg.MessageIds, } @@ -364,7 +364,7 @@ func TestDuplicate(t *testing.T) { func TestHappyPathSingleUpdate(t *testing.T) { svc := SimpleInstructionConverter() msg := NewMsg(submit, assigned, running, jobRunSucceeded, jobSucceeded) - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expected := &model.InstructionSet{ JobsToCreate: []*model.CreateJobInstruction{&expectedSubmit}, JobsToUpdate: []*model.UpdateJobInstruction{&expectedLeased, &expectedRunning, &expectedJobSucceeded}, @@ -384,7 +384,7 @@ func TestHappyPathMultiUpdate(t *testing.T) { svc := SimpleInstructionConverter() // Submit msg1 := NewMsg(submit) - instructions := svc.Convert(context.Background(), msg1) + instructions := svc.Convert(armadacontext.Background(), msg1) expected := &model.InstructionSet{ JobsToCreate: []*model.CreateJobInstruction{&expectedSubmit}, MessageIds: msg1.MessageIds, @@ -393,7 +393,7 @@ func TestHappyPathMultiUpdate(t *testing.T) { // Leased msg2 := NewMsg(assigned) - instructions = svc.Convert(context.Background(), msg2) + instructions = svc.Convert(armadacontext.Background(), msg2) expected = &model.InstructionSet{ JobsToUpdate: []*model.UpdateJobInstruction{&expectedLeased}, JobRunsToCreate: []*model.CreateJobRunInstruction{&expectedLeasedRun}, @@ -403,7 +403,7 @@ func TestHappyPathMultiUpdate(t *testing.T) { // Running msg3 := NewMsg(running) - instructions = svc.Convert(context.Background(), msg3) + instructions = 
svc.Convert(armadacontext.Background(), msg3) expected = &model.InstructionSet{ JobsToUpdate: []*model.UpdateJobInstruction{&expectedRunning}, JobRunsToUpdate: []*model.UpdateJobRunInstruction{&expectedRunningRun}, @@ -413,7 +413,7 @@ func TestHappyPathMultiUpdate(t *testing.T) { // Run Succeeded msg4 := NewMsg(jobRunSucceeded) - instructions = svc.Convert(context.Background(), msg4) + instructions = svc.Convert(armadacontext.Background(), msg4) expected = &model.InstructionSet{ JobRunsToUpdate: []*model.UpdateJobRunInstruction{&expectedJobRunSucceeded}, MessageIds: msg4.MessageIds, @@ -422,7 +422,7 @@ func TestHappyPathMultiUpdate(t *testing.T) { // Job Succeeded msg5 := NewMsg(jobSucceeded) - instructions = svc.Convert(context.Background(), msg5) + instructions = svc.Convert(armadacontext.Background(), msg5) expected = &model.InstructionSet{ JobsToUpdate: []*model.UpdateJobInstruction{&expectedJobSucceeded}, MessageIds: msg5.MessageIds, @@ -433,7 +433,7 @@ func TestHappyPathMultiUpdate(t *testing.T) { func TestCancelled(t *testing.T) { svc := SimpleInstructionConverter() msg := NewMsg(jobCancelled) - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expected := &model.InstructionSet{ JobsToUpdate: []*model.UpdateJobInstruction{&expectedJobCancelled}, MessageIds: msg.MessageIds, @@ -444,7 +444,7 @@ func TestCancelled(t *testing.T) { func TestReprioritised(t *testing.T) { svc := SimpleInstructionConverter() msg := NewMsg(jobReprioritised) - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expected := &model.InstructionSet{ JobsToUpdate: []*model.UpdateJobInstruction{&expectedJobReprioritised}, MessageIds: msg.MessageIds, @@ -455,7 +455,7 @@ func TestReprioritised(t *testing.T) { func TestPreempted(t *testing.T) { svc := SimpleInstructionConverter() msg := NewMsg(jobPreempted) - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expected := &model.InstructionSet{ JobRunsToUpdate: []*model.UpdateJobRunInstruction{&expectedJobRunPreempted}, MessageIds: msg.MessageIds, @@ -466,7 +466,7 @@ func TestPreempted(t *testing.T) { func TestFailed(t *testing.T) { svc := SimpleInstructionConverter() msg := NewMsg(jobRunFailed) - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expected := &model.InstructionSet{ JobRunsToUpdate: []*model.UpdateJobRunInstruction{&expectedFailed}, JobRunContainersToCreate: []*model.CreateJobRunContainerInstruction{&expectedJobRunContainer}, @@ -478,7 +478,7 @@ func TestFailed(t *testing.T) { func TestFailedWithMissingRunId(t *testing.T) { svc := SimpleInstructionConverter() msg := NewMsg(jobLeaseReturned) - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) jobRun := instructions.JobRunsToCreate[0] assert.NotEqual(t, eventutil.LEGACY_RUN_ID, jobRun.RunId) expected := &model.InstructionSet{ @@ -534,7 +534,7 @@ func TestHandlePodTerminated(t *testing.T) { svc := SimpleInstructionConverter() msg := NewMsg(podTerminated) - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expected := &model.InstructionSet{ MessageIds: msg.MessageIds, } @@ -565,7 +565,7 @@ func TestHandleJobLeaseReturned(t *testing.T) { svc := SimpleInstructionConverter() msg := NewMsg(leaseReturned) - 
instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expected := &model.InstructionSet{ JobRunsToUpdate: []*model.UpdateJobRunInstruction{{ RunId: runIdString, @@ -616,7 +616,7 @@ func TestHandlePodUnschedulable(t *testing.T) { svc := SimpleInstructionConverter() msg := NewMsg(podUnschedulable) - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expected := &model.InstructionSet{ JobRunsToUpdate: []*model.UpdateJobRunInstruction{{ RunId: runIdString, @@ -639,7 +639,7 @@ func TestHandleDuplicate(t *testing.T) { svc := SimpleInstructionConverter() msg := NewMsg(duplicate) - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expected := &model.InstructionSet{ JobsToUpdate: []*model.UpdateJobInstruction{ { @@ -685,7 +685,7 @@ func TestSubmitWithNullChar(t *testing.T) { }) svc := SimpleInstructionConverter() - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) assert.Len(t, instructions.JobsToCreate, 1) assert.NotContains(t, string(instructions.JobsToCreate[0].JobProto), "\\u0000") } @@ -716,7 +716,7 @@ func TestFailedWithNullCharInError(t *testing.T) { }) svc := SimpleInstructionConverter() - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expectedJobRunsToUpdate := []*model.UpdateJobRunInstruction{ { RunId: runIdString, @@ -741,7 +741,7 @@ func TestInvalidEvent(t *testing.T) { // Check that the (valid) Submit is processed, but the invalid message is discarded svc := SimpleInstructionConverter() msg := NewMsg(invalidEvent, submit) - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expected := &model.InstructionSet{ JobsToCreate: []*model.CreateJobInstruction{&expectedSubmit}, MessageIds: msg.MessageIds, diff --git a/internal/lookoutingester/lookoutdb/insertion.go b/internal/lookoutingester/lookoutdb/insertion.go index a22b2eab29b..6009f4cadbc 100644 --- a/internal/lookoutingester/lookoutdb/insertion.go +++ b/internal/lookoutingester/lookoutdb/insertion.go @@ -1,7 +1,6 @@ package lookoutdb import ( - "context" "fmt" "sync" "time" @@ -11,6 +10,7 @@ import ( "github.com/pkg/errors" log "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/armadaerrors" "github.com/armadaproject/armada/internal/common/database" "github.com/armadaproject/armada/internal/common/ingest" @@ -45,7 +45,7 @@ func NewLookoutDb( // * Job Run Updates, New Job Containers // In each case we first try to bach insert the rows using the postgres copy protocol. If this fails then we try a // slower, serial insert and discard any rows that cannot be inserted. 
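// --- Illustrative sketch (not part of the patch) ----------------------------
// The comment above describes the Store strategy: first attempt a batch
// insert via the postgres COPY protocol, then fall back to slower row-by-row
// inserts so that only unusable rows are discarded. This change only swaps
// the ctx parameter to *armadacontext.Context; the fallback shape, shown
// below in isolation, is unchanged. createWithFallback, createBatch and
// createScalar are hypothetical stand-ins for method pairs such as
// CreateJobsBatch/CreateJobsScalar.
package example

import (
	log "github.com/sirupsen/logrus"

	"github.com/armadaproject/armada/internal/common/armadacontext"
)

func createWithFallback(
	ctx *armadacontext.Context,
	createBatch func(*armadacontext.Context) error,
	createScalar func(*armadacontext.Context),
) {
	if err := createBatch(ctx); err != nil {
		log.WithError(err).Warn("batch insert failed, falling back to scalar inserts")
		createScalar(ctx)
	}
}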
-func (l *LookoutDb) Store(ctx context.Context, instructions *model.InstructionSet) error { +func (l *LookoutDb) Store(ctx *armadacontext.Context, instructions *model.InstructionSet) error { jobsToUpdate := instructions.JobsToUpdate jobRunsToUpdate := instructions.JobRunsToUpdate @@ -92,7 +92,7 @@ func (l *LookoutDb) Store(ctx context.Context, instructions *model.InstructionSe return nil } -func (l *LookoutDb) CreateJobs(ctx context.Context, instructions []*model.CreateJobInstruction) { +func (l *LookoutDb) CreateJobs(ctx *armadacontext.Context, instructions []*model.CreateJobInstruction) { if len(instructions) == 0 { return } @@ -109,7 +109,7 @@ func (l *LookoutDb) CreateJobs(ctx context.Context, instructions []*model.Create } } -func (l *LookoutDb) UpdateJobs(ctx context.Context, instructions []*model.UpdateJobInstruction) { +func (l *LookoutDb) UpdateJobs(ctx *armadacontext.Context, instructions []*model.UpdateJobInstruction) { if len(instructions) == 0 { return } @@ -127,7 +127,7 @@ func (l *LookoutDb) UpdateJobs(ctx context.Context, instructions []*model.Update } } -func (l *LookoutDb) CreateJobRuns(ctx context.Context, instructions []*model.CreateJobRunInstruction) { +func (l *LookoutDb) CreateJobRuns(ctx *armadacontext.Context, instructions []*model.CreateJobRunInstruction) { if len(instructions) == 0 { return } @@ -144,7 +144,7 @@ func (l *LookoutDb) CreateJobRuns(ctx context.Context, instructions []*model.Cre } } -func (l *LookoutDb) UpdateJobRuns(ctx context.Context, instructions []*model.UpdateJobRunInstruction) { +func (l *LookoutDb) UpdateJobRuns(ctx *armadacontext.Context, instructions []*model.UpdateJobRunInstruction) { if len(instructions) == 0 { return } @@ -161,7 +161,7 @@ func (l *LookoutDb) UpdateJobRuns(ctx context.Context, instructions []*model.Upd } } -func (l *LookoutDb) CreateUserAnnotations(ctx context.Context, instructions []*model.CreateUserAnnotationInstruction) { +func (l *LookoutDb) CreateUserAnnotations(ctx *armadacontext.Context, instructions []*model.CreateUserAnnotationInstruction) { if len(instructions) == 0 { return } @@ -178,7 +178,7 @@ func (l *LookoutDb) CreateUserAnnotations(ctx context.Context, instructions []*m } } -func (l *LookoutDb) CreateJobRunContainers(ctx context.Context, instructions []*model.CreateJobRunContainerInstruction) { +func (l *LookoutDb) CreateJobRunContainers(ctx *armadacontext.Context, instructions []*model.CreateJobRunContainerInstruction) { if len(instructions) == 0 { return } @@ -195,13 +195,13 @@ func (l *LookoutDb) CreateJobRunContainers(ctx context.Context, instructions []* } } -func (l *LookoutDb) CreateJobsBatch(ctx context.Context, instructions []*model.CreateJobInstruction) error { +func (l *LookoutDb) CreateJobsBatch(ctx *armadacontext.Context, instructions []*model.CreateJobInstruction) error { return withDatabaseRetryInsert(func() error { tmpTable := database.UniqueTableName("job") createTmp := func(tx pgx.Tx) error { _, err := tx.Exec(ctx, fmt.Sprintf(` - CREATE TEMPORARY TABLE %s + CREATE TEMPORARY TABLE %s ( job_id varchar(32), queue varchar(512), @@ -258,7 +258,7 @@ func (l *LookoutDb) CreateJobsBatch(ctx context.Context, instructions []*model.C } // CreateJobsScalar will insert jobs one by one into the database -func (l *LookoutDb) CreateJobsScalar(ctx context.Context, instructions []*model.CreateJobInstruction) { +func (l *LookoutDb) CreateJobsScalar(ctx *armadacontext.Context, instructions []*model.CreateJobInstruction) { sqlStatement := `INSERT INTO job (job_id, queue, owner, jobset, priority, submitted, 
orig_job_spec, state, job_updated) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) ON CONFLICT DO NOTHING` @@ -276,7 +276,7 @@ func (l *LookoutDb) CreateJobsScalar(ctx context.Context, instructions []*model. } } -func (l *LookoutDb) UpdateJobsBatch(ctx context.Context, instructions []*model.UpdateJobInstruction) error { +func (l *LookoutDb) UpdateJobsBatch(ctx *armadacontext.Context, instructions []*model.UpdateJobInstruction) error { return withDatabaseRetryInsert(func() error { tmpTable := database.UniqueTableName("job") @@ -337,7 +337,7 @@ func (l *LookoutDb) UpdateJobsBatch(ctx context.Context, instructions []*model.U }) } -func (l *LookoutDb) UpdateJobsScalar(ctx context.Context, instructions []*model.UpdateJobInstruction) { +func (l *LookoutDb) UpdateJobsScalar(ctx *armadacontext.Context, instructions []*model.UpdateJobInstruction) { sqlStatement := `UPDATE job SET priority = coalesce($1, priority), @@ -360,7 +360,7 @@ func (l *LookoutDb) UpdateJobsScalar(ctx context.Context, instructions []*model. } } -func (l *LookoutDb) CreateJobRunsBatch(ctx context.Context, instructions []*model.CreateJobRunInstruction) error { +func (l *LookoutDb) CreateJobRunsBatch(ctx *armadacontext.Context, instructions []*model.CreateJobRunInstruction) error { return withDatabaseRetryInsert(func() error { tmpTable := database.UniqueTableName("job_run") @@ -410,7 +410,7 @@ func (l *LookoutDb) CreateJobRunsBatch(ctx context.Context, instructions []*mode }) } -func (l *LookoutDb) CreateJobRunsScalar(ctx context.Context, instructions []*model.CreateJobRunInstruction) { +func (l *LookoutDb) CreateJobRunsScalar(ctx *armadacontext.Context, instructions []*model.CreateJobRunInstruction) { sqlStatement := `INSERT INTO job_run (run_id, job_id, created, cluster) VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING` @@ -428,7 +428,7 @@ func (l *LookoutDb) CreateJobRunsScalar(ctx context.Context, instructions []*mod } } -func (l *LookoutDb) UpdateJobRunsBatch(ctx context.Context, instructions []*model.UpdateJobRunInstruction) error { +func (l *LookoutDb) UpdateJobRunsBatch(ctx *armadacontext.Context, instructions []*model.UpdateJobRunInstruction) error { return withDatabaseRetryInsert(func() error { tmpTable := database.UniqueTableName("job_run") @@ -499,7 +499,7 @@ func (l *LookoutDb) UpdateJobRunsBatch(ctx context.Context, instructions []*mode }) } -func (l *LookoutDb) UpdateJobRunsScalar(ctx context.Context, instructions []*model.UpdateJobRunInstruction) { +func (l *LookoutDb) UpdateJobRunsScalar(ctx *armadacontext.Context, instructions []*model.UpdateJobRunInstruction) { sqlStatement := `UPDATE job_run SET node = coalesce($1, node), @@ -525,7 +525,7 @@ func (l *LookoutDb) UpdateJobRunsScalar(ctx context.Context, instructions []*mod } } -func (l *LookoutDb) CreateUserAnnotationsBatch(ctx context.Context, instructions []*model.CreateUserAnnotationInstruction) error { +func (l *LookoutDb) CreateUserAnnotationsBatch(ctx *armadacontext.Context, instructions []*model.CreateUserAnnotationInstruction) error { return withDatabaseRetryInsert(func() error { tmpTable := database.UniqueTableName("user_annotation_lookup") @@ -573,9 +573,9 @@ func (l *LookoutDb) CreateUserAnnotationsBatch(ctx context.Context, instructions }) } -func (l *LookoutDb) CreateUserAnnotationsScalar(ctx context.Context, instructions []*model.CreateUserAnnotationInstruction) { +func (l *LookoutDb) CreateUserAnnotationsScalar(ctx *armadacontext.Context, instructions []*model.CreateUserAnnotationInstruction) { sqlStatement := `INSERT INTO user_annotation_lookup 
(job_id, key, value) - VALUES ($1, $2, $3) + VALUES ($1, $2, $3) ON CONFLICT DO NOTHING` for _, i := range instructions { err := withDatabaseRetryInsert(func() error { @@ -592,7 +592,7 @@ func (l *LookoutDb) CreateUserAnnotationsScalar(ctx context.Context, instruction } } -func (l *LookoutDb) CreateJobRunContainersBatch(ctx context.Context, instructions []*model.CreateJobRunContainerInstruction) error { +func (l *LookoutDb) CreateJobRunContainersBatch(ctx *armadacontext.Context, instructions []*model.CreateJobRunContainerInstruction) error { return withDatabaseRetryInsert(func() error { tmpTable := database.UniqueTableName("job_run_container") createTmp := func(tx pgx.Tx) error { @@ -641,7 +641,7 @@ func (l *LookoutDb) CreateJobRunContainersBatch(ctx context.Context, instruction }) } -func (l *LookoutDb) CreateJobRunContainersScalar(ctx context.Context, instructions []*model.CreateJobRunContainerInstruction) { +func (l *LookoutDb) CreateJobRunContainersScalar(ctx *armadacontext.Context, instructions []*model.CreateJobRunContainerInstruction) { sqlStatement := `INSERT INTO job_run_container (run_id, container_name, exit_code) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING` @@ -659,7 +659,7 @@ func (l *LookoutDb) CreateJobRunContainersScalar(ctx context.Context, instructio } } -func batchInsert(ctx context.Context, db *pgxpool.Pool, createTmp func(pgx.Tx) error, +func batchInsert(ctx *armadacontext.Context, db *pgxpool.Pool, createTmp func(pgx.Tx) error, insertTmp func(pgx.Tx) error, copyToDest func(pgx.Tx) error, ) error { return pgx.BeginTxFunc(ctx, db, pgx.TxOptions{ @@ -776,7 +776,7 @@ func conflateJobRunUpdates(updates []*model.UpdateJobRunInstruction) []*model.Up // in the terminal state. If, however, the database returns a non-retryable error it will give up and simply not // filter out any events as the job state is undetermined. 
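// --- Illustrative sketch (not part of the patch) ----------------------------
// The comment above outlines filterEventsForTerminalJobs: drop updates aimed
// at jobs the database already records as terminal, but if the lookup fails
// with a non-retryable error, keep every update because the job states are
// undetermined. Only the ctx parameter changes in this diff. The sketch below
// captures that control flow; isTerminal and filterForIllustration are
// hypothetical names, with isTerminal standing in for the real query against
// the job table.
package example

import (
	"github.com/jackc/pgx/v5/pgxpool"

	"github.com/armadaproject/armada/internal/common/armadacontext"
	"github.com/armadaproject/armada/internal/lookoutingester/model"
)

// isTerminal reports whether the job targeted by u is already in a terminal
// state. The real implementation queries the job table; this stub always
// answers false.
func isTerminal(ctx *armadacontext.Context, db *pgxpool.Pool, u *model.UpdateJobInstruction) (bool, error) {
	return false, nil
}

func filterForIllustration(ctx *armadacontext.Context, db *pgxpool.Pool, updates []*model.UpdateJobInstruction) []*model.UpdateJobInstruction {
	kept := make([]*model.UpdateJobInstruction, 0, len(updates))
	for _, u := range updates {
		terminal, err := isTerminal(ctx, db, u)
		if err != nil {
			// Non-retryable error: give up on filtering and keep everything.
			return updates
		}
		if !terminal {
			kept = append(kept, u)
		}
	}
	return kept
}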
func filterEventsForTerminalJobs( - ctx context.Context, + ctx *armadacontext.Context, db *pgxpool.Pool, instructions []*model.UpdateJobInstruction, m *metrics.Metrics, diff --git a/internal/lookoutingester/lookoutdb/insertion_test.go b/internal/lookoutingester/lookoutdb/insertion_test.go index 25a3ff1af03..079912a68c4 100644 --- a/internal/lookoutingester/lookoutdb/insertion_test.go +++ b/internal/lookoutingester/lookoutdb/insertion_test.go @@ -1,21 +1,19 @@ package lookoutdb import ( - "context" "fmt" "sort" "testing" "time" - "github.com/armadaproject/armada/internal/common/database/lookout" - "github.com/apache/pulsar-client-go/pulsar" "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "k8s.io/utils/pointer" + "github.com/armadaproject/armada/internal/common/armadacontext" + "github.com/armadaproject/armada/internal/common/database/lookout" "github.com/armadaproject/armada/internal/common/pulsarutils" "github.com/armadaproject/armada/internal/lookout/configuration" "github.com/armadaproject/armada/internal/lookout/repository" @@ -216,24 +214,24 @@ func TestCreateJobsBatch(t *testing.T) { err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { ldb := getTestLookoutDb(db) // Insert - err := ldb.CreateJobsBatch(context.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) require.NoError(t, err) job := getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterSubmit, job) // Insert again and test that it's idempotent - err = ldb.CreateJobsBatch(context.Background(), defaultInstructionSet().JobsToCreate) + err = ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) require.NoError(t, err) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterSubmit, job) // If a row is bad then we should return an error and no updates should happen - _, err = db.Exec(context.Background(), "DELETE FROM job") + _, err = db.Exec(armadacontext.Background(), "DELETE FROM job") require.NoError(t, err) invalidJob := &model.CreateJobInstruction{ JobId: invalidId, } - err = ldb.CreateJobsBatch(context.Background(), append(defaultInstructionSet().JobsToCreate, invalidJob)) + err = ldb.CreateJobsBatch(armadacontext.Background(), append(defaultInstructionSet().JobsToCreate, invalidJob)) assert.Error(t, err) assertNoRows(t, db, "job") return nil @@ -245,29 +243,29 @@ func TestUpdateJobsBatch(t *testing.T) { err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { ldb := getTestLookoutDb(db) // Insert - err := ldb.CreateJobsBatch(context.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) require.NoError(t, err) // Update - err = ldb.UpdateJobsBatch(context.Background(), defaultInstructionSet().JobsToUpdate) + err = ldb.UpdateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToUpdate) require.NoError(t, err) job := getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterUpdate, job) - err = ldb.UpdateJobsBatch(context.Background(), defaultInstructionSet().JobsToUpdate) + err = ldb.UpdateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToUpdate) require.NoError(t, err) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterUpdate, job) // If an update is bad then we should return an error and no updates should happen - _, err = db.Exec(context.Background(), "DELETE FROM job") 
+ _, err = db.Exec(armadacontext.Background(), "DELETE FROM job") require.NoError(t, err) - err = ldb.CreateJobsBatch(context.Background(), defaultInstructionSet().JobsToCreate) + err = ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) require.NoError(t, err) invalidUpdate := &model.UpdateJobInstruction{ JobId: invalidId, } - err = ldb.UpdateJobsBatch(context.Background(), append(defaultInstructionSet().JobsToUpdate, invalidUpdate)) + err = ldb.UpdateJobsBatch(armadacontext.Background(), append(defaultInstructionSet().JobsToUpdate, invalidUpdate)) assert.Error(t, err) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterSubmit, job) @@ -280,28 +278,28 @@ func TestUpdateJobsScalar(t *testing.T) { err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { ldb := getTestLookoutDb(db) // Insert - err := ldb.CreateJobsBatch(context.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) require.NoError(t, err) // Update - ldb.UpdateJobsScalar(context.Background(), defaultInstructionSet().JobsToUpdate) + ldb.UpdateJobsScalar(armadacontext.Background(), defaultInstructionSet().JobsToUpdate) job := getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterUpdate, job) // Insert again and test that it's idempotent - ldb.UpdateJobsScalar(context.Background(), defaultInstructionSet().JobsToUpdate) + ldb.UpdateJobsScalar(armadacontext.Background(), defaultInstructionSet().JobsToUpdate) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterUpdate, job) // If a update is bad then we should return an error and no updates should happen - _, err = db.Exec(context.Background(), "DELETE FROM job") + _, err = db.Exec(armadacontext.Background(), "DELETE FROM job") require.NoError(t, err) - err = ldb.CreateJobsBatch(context.Background(), defaultInstructionSet().JobsToCreate) + err = ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) require.NoError(t, err) invalidUpdate := &model.UpdateJobInstruction{ JobId: invalidId, } - ldb.UpdateJobsScalar(context.Background(), append(defaultInstructionSet().JobsToUpdate, invalidUpdate)) + ldb.UpdateJobsScalar(armadacontext.Background(), append(defaultInstructionSet().JobsToUpdate, invalidUpdate)) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterUpdate, job) return nil @@ -376,13 +374,13 @@ func TestUpdateJobsWithTerminal(t *testing.T) { ldb := getTestLookoutDb(db) // Insert - ldb.CreateJobs(context.Background(), initial) + ldb.CreateJobs(armadacontext.Background(), initial) // Mark the jobs terminal - ldb.UpdateJobs(context.Background(), update1) + ldb.UpdateJobs(armadacontext.Background(), update1) // Update the jobs - these should be discarded - ldb.UpdateJobs(context.Background(), update2) + ldb.UpdateJobs(armadacontext.Background(), update2) // Assert the states are still terminal job := getJob(t, db, jobIdString) @@ -403,22 +401,22 @@ func TestCreateJobsScalar(t *testing.T) { err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { ldb := getTestLookoutDb(db) // Simple create - ldb.CreateJobsScalar(context.Background(), defaultInstructionSet().JobsToCreate) + ldb.CreateJobsScalar(armadacontext.Background(), defaultInstructionSet().JobsToCreate) job := getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterSubmit, job) // Insert again and check for idempotency - ldb.CreateJobsScalar(context.Background(), defaultInstructionSet().JobsToCreate) + 
ldb.CreateJobsScalar(armadacontext.Background(), defaultInstructionSet().JobsToCreate) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterSubmit, job) // If a row is bad then we should update only the good rows - _, err := ldb.db.Exec(context.Background(), "DELETE FROM job") + _, err := ldb.db.Exec(armadacontext.Background(), "DELETE FROM job") require.NoError(t, err) invalidJob := &model.CreateJobInstruction{ JobId: invalidId, } - ldb.CreateJobsScalar(context.Background(), append(defaultInstructionSet().JobsToCreate, invalidJob)) + ldb.CreateJobsScalar(armadacontext.Background(), append(defaultInstructionSet().JobsToCreate, invalidJob)) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterSubmit, job) return nil @@ -430,28 +428,28 @@ func TestCreateJobRunsBatch(t *testing.T) { err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { ldb := getTestLookoutDb(db) // Need to make sure we have a job, so we can satisfy PK - err := ldb.CreateJobsBatch(context.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) require.NoError(t, err) // Insert - err = ldb.CreateJobRunsBatch(context.Background(), defaultInstructionSet().JobRunsToCreate) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) require.NoError(t, err) job := getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRun, job) // Insert again and test that it's idempotent - err = ldb.CreateJobRunsBatch(context.Background(), defaultInstructionSet().JobRunsToCreate) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) require.NoError(t, err) job = getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRun, job) // If a row is bad then we should return an error and no updates should happen - _, err = ldb.db.Exec(context.Background(), "DELETE FROM job_run") + _, err = ldb.db.Exec(armadacontext.Background(), "DELETE FROM job_run") require.NoError(t, err) invalidRun := &model.CreateJobRunInstruction{ RunId: invalidId, } - err = ldb.CreateJobRunsBatch(context.Background(), append(defaultInstructionSet().JobRunsToCreate, invalidRun)) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), append(defaultInstructionSet().JobRunsToCreate, invalidRun)) assert.Error(t, err) assertNoRows(t, db, "job_run") return nil @@ -463,26 +461,26 @@ func TestCreateJobRunsScalar(t *testing.T) { err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { ldb := getTestLookoutDb(db) // Need to make sure we have a job, so we can satisfy PK - err := ldb.CreateJobsBatch(context.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) require.NoError(t, err) // Insert - ldb.CreateJobRunsScalar(context.Background(), defaultInstructionSet().JobRunsToCreate) + ldb.CreateJobRunsScalar(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) job := getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRun, job) // Insert again and test that it's idempotent - ldb.CreateJobRunsScalar(context.Background(), defaultInstructionSet().JobRunsToCreate) + ldb.CreateJobRunsScalar(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) job = getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRun, job) // If a row is bad then we create rows that can be created - _, err = db.Exec(context.Background(), "DELETE FROM job_run") + _, err = 
db.Exec(armadacontext.Background(), "DELETE FROM job_run") require.NoError(t, err) invalidRun := &model.CreateJobRunInstruction{ RunId: invalidId, } - ldb.CreateJobRunsScalar(context.Background(), append(defaultInstructionSet().JobRunsToCreate, invalidRun)) + ldb.CreateJobRunsScalar(armadacontext.Background(), append(defaultInstructionSet().JobRunsToCreate, invalidRun)) job = getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRun, job) return nil @@ -494,33 +492,33 @@ func TestUpdateJobRunsBatch(t *testing.T) { err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { ldb := getTestLookoutDb(db) // Need to make sure we have a job and run - err := ldb.CreateJobsBatch(context.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) require.NoError(t, err) - err = ldb.CreateJobRunsBatch(context.Background(), defaultInstructionSet().JobRunsToCreate) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) require.NoError(t, err) // Update - err = ldb.UpdateJobRunsBatch(context.Background(), defaultInstructionSet().JobRunsToUpdate) + err = ldb.UpdateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToUpdate) require.NoError(t, err) run := getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRunAfterUpdate, run) // Update again and test that it's idempotent - err = ldb.UpdateJobRunsBatch(context.Background(), defaultInstructionSet().JobRunsToUpdate) + err = ldb.UpdateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToUpdate) require.NoError(t, err) run = getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRunAfterUpdate, run) // If a row is bad then we should return an error and no updates should happen - _, err = db.Exec(context.Background(), "DELETE FROM job_run;") + _, err = db.Exec(armadacontext.Background(), "DELETE FROM job_run;") require.NoError(t, err) invalidRun := &model.UpdateJobRunInstruction{ RunId: invalidId, } - err = ldb.CreateJobRunsBatch(context.Background(), defaultInstructionSet().JobRunsToCreate) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) require.NoError(t, err) - err = ldb.UpdateJobRunsBatch(context.Background(), append(defaultInstructionSet().JobRunsToUpdate, invalidRun)) + err = ldb.UpdateJobRunsBatch(armadacontext.Background(), append(defaultInstructionSet().JobRunsToUpdate, invalidRun)) assert.Error(t, err) run = getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRun, run) @@ -533,33 +531,33 @@ func TestUpdateJobRunsScalar(t *testing.T) { err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { ldb := getTestLookoutDb(db) // Need to make sure we have a job and run - err := ldb.CreateJobsBatch(context.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) require.NoError(t, err) - err = ldb.CreateJobRunsBatch(context.Background(), defaultInstructionSet().JobRunsToCreate) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) require.NoError(t, err) // Update - ldb.UpdateJobRunsScalar(context.Background(), defaultInstructionSet().JobRunsToUpdate) + ldb.UpdateJobRunsScalar(armadacontext.Background(), defaultInstructionSet().JobRunsToUpdate) require.NoError(t, err) run := getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRunAfterUpdate, run) // Update again and test that 
it's idempotent - ldb.UpdateJobRunsScalar(context.Background(), defaultInstructionSet().JobRunsToUpdate) + ldb.UpdateJobRunsScalar(armadacontext.Background(), defaultInstructionSet().JobRunsToUpdate) require.NoError(t, err) run = getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRunAfterUpdate, run) // If a row is bad then we should update the rows we can - _, err = ldb.db.Exec(context.Background(), "DELETE FROM job_run;") + _, err = ldb.db.Exec(armadacontext.Background(), "DELETE FROM job_run;") require.NoError(t, err) invalidRun := &model.UpdateJobRunInstruction{ RunId: invalidId, } - err = ldb.CreateJobRunsBatch(context.Background(), defaultInstructionSet().JobRunsToCreate) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) assert.Nil(t, err) - ldb.UpdateJobRunsScalar(context.Background(), append(defaultInstructionSet().JobRunsToUpdate, invalidRun)) + ldb.UpdateJobRunsScalar(armadacontext.Background(), append(defaultInstructionSet().JobRunsToUpdate, invalidRun)) run = getJobRun(t, ldb.db, runIdString) assert.Equal(t, expectedJobRunAfterUpdate, run) return nil @@ -571,28 +569,28 @@ func TestCreateUserAnnotationsBatch(t *testing.T) { err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { ldb := getTestLookoutDb(db) // Need to make sure we have a job - err := ldb.CreateJobsBatch(context.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) require.NoError(t, err) // Insert - err = ldb.CreateUserAnnotationsBatch(context.Background(), defaultInstructionSet().UserAnnotationsToCreate) + err = ldb.CreateUserAnnotationsBatch(armadacontext.Background(), defaultInstructionSet().UserAnnotationsToCreate) require.NoError(t, err) annotation := getUserAnnotationLookup(t, db, jobIdString) assert.Equal(t, expectedUserAnnotation, annotation) // Insert again and test that it's idempotent - err = ldb.CreateUserAnnotationsBatch(context.Background(), defaultInstructionSet().UserAnnotationsToCreate) + err = ldb.CreateUserAnnotationsBatch(armadacontext.Background(), defaultInstructionSet().UserAnnotationsToCreate) require.NoError(t, err) annotation = getUserAnnotationLookup(t, db, jobIdString) assert.Equal(t, expectedUserAnnotation, annotation) // If a row is bad then we should return an error and no updates should happen - _, err = ldb.db.Exec(context.Background(), "DELETE FROM user_annotation_lookup") + _, err = ldb.db.Exec(armadacontext.Background(), "DELETE FROM user_annotation_lookup") require.NoError(t, err) invalidAnnotation := &model.CreateUserAnnotationInstruction{ JobId: invalidId, } - err = ldb.CreateUserAnnotationsBatch(context.Background(), append(defaultInstructionSet().UserAnnotationsToCreate, invalidAnnotation)) + err = ldb.CreateUserAnnotationsBatch(armadacontext.Background(), append(defaultInstructionSet().UserAnnotationsToCreate, invalidAnnotation)) assert.Error(t, err) assertNoRows(t, ldb.db, "user_annotation_lookup") return nil @@ -603,7 +601,7 @@ func TestCreateUserAnnotationsBatch(t *testing.T) { func TestEmptyUpdate(t *testing.T) { err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { ldb := getTestLookoutDb(db) - storeErr := ldb.Store(context.Background(), &model.InstructionSet{}) + storeErr := ldb.Store(armadacontext.Background(), &model.InstructionSet{}) require.NoError(t, storeErr) assertNoRows(t, ldb.db, "job") assertNoRows(t, ldb.db, "job_run") @@ -618,26 +616,26 @@ func TestCreateUserAnnotationsScalar(t *testing.T) { 
err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { ldb := getTestLookoutDb(db) // Need to make sure we have a job - err := ldb.CreateJobsBatch(context.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) require.NoError(t, err) // Insert - ldb.CreateUserAnnotationsScalar(context.Background(), defaultInstructionSet().UserAnnotationsToCreate) + ldb.CreateUserAnnotationsScalar(armadacontext.Background(), defaultInstructionSet().UserAnnotationsToCreate) annotation := getUserAnnotationLookup(t, db, jobIdString) assert.Equal(t, expectedUserAnnotation, annotation) // Insert again and test that it's idempotent - ldb.CreateUserAnnotationsScalar(context.Background(), defaultInstructionSet().UserAnnotationsToCreate) + ldb.CreateUserAnnotationsScalar(armadacontext.Background(), defaultInstructionSet().UserAnnotationsToCreate) annotation = getUserAnnotationLookup(t, db, jobIdString) assert.Equal(t, expectedUserAnnotation, annotation) // If a row is bad then we should update the rows we can - _, err = ldb.db.Exec(context.Background(), "DELETE FROM user_annotation_lookup") + _, err = ldb.db.Exec(armadacontext.Background(), "DELETE FROM user_annotation_lookup") require.NoError(t, err) invalidAnnotation := &model.CreateUserAnnotationInstruction{ JobId: invalidId, } - ldb.CreateUserAnnotationsScalar(context.Background(), append(defaultInstructionSet().UserAnnotationsToCreate, invalidAnnotation)) + ldb.CreateUserAnnotationsScalar(armadacontext.Background(), append(defaultInstructionSet().UserAnnotationsToCreate, invalidAnnotation)) annotation = getUserAnnotationLookup(t, ldb.db, jobIdString) assert.Equal(t, expectedUserAnnotation, annotation) return nil @@ -649,7 +647,7 @@ func TestUpdate(t *testing.T) { err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { ldb := getTestLookoutDb(db) // Do the update - storeErr := ldb.Store(context.Background(), defaultInstructionSet()) + storeErr := ldb.Store(armadacontext.Background(), defaultInstructionSet()) require.NoError(t, storeErr) job := getJob(t, ldb.db, jobIdString) jobRun := getJobRun(t, ldb.db, runIdString) @@ -748,7 +746,7 @@ func TestConflateJobRunUpdates(T *testing.T) { func getJob(t *testing.T, db *pgxpool.Pool, jobId string) JobRow { job := JobRow{} r := db.QueryRow( - context.Background(), + armadacontext.Background(), `SELECT job_id, queue, owner, jobset, priority, submitted, state, duplicate, job_updated, orig_job_spec, cancelled FROM job WHERE job_id = $1`, jobId) err := r.Scan( @@ -771,7 +769,7 @@ func getJob(t *testing.T, db *pgxpool.Pool, jobId string) JobRow { func getJobRun(t *testing.T, db *pgxpool.Pool, runId string) JobRunRow { run := JobRunRow{} r := db.QueryRow( - context.Background(), + armadacontext.Background(), `SELECT run_id, job_id, cluster, node, created, started, finished, succeeded, error, pod_number, unable_to_schedule, preempted FROM job_run WHERE run_id = $1`, runId) err := r.Scan( @@ -795,7 +793,7 @@ func getJobRun(t *testing.T, db *pgxpool.Pool, runId string) JobRunRow { func getJobRunContainer(t *testing.T, db *pgxpool.Pool, runId string) JobRunContainerRow { container := JobRunContainerRow{} r := db.QueryRow( - context.Background(), + armadacontext.Background(), `SELECT run_id, container_name, exit_code FROM job_run_container WHERE run_id = $1`, runId) err := r.Scan(&container.RunId, &container.ContainerName, &container.ExitCode) @@ -806,7 +804,7 @@ func getJobRunContainer(t *testing.T, db *pgxpool.Pool, runId 
string) JobRunCont func getUserAnnotationLookup(t *testing.T, db *pgxpool.Pool, jobId string) UserAnnotationRow { annotation := UserAnnotationRow{} r := db.QueryRow( - context.Background(), + armadacontext.Background(), `SELECT job_id, key, value FROM user_annotation_lookup WHERE job_id = $1`, jobId) err := r.Scan(&annotation.JobId, &annotation.Key, &annotation.Value) @@ -816,7 +814,7 @@ func getUserAnnotationLookup(t *testing.T, db *pgxpool.Pool, jobId string) UserA func assertNoRows(t *testing.T, db *pgxpool.Pool, table string) { var count int - r := db.QueryRow(context.Background(), fmt.Sprintf("SELECT COUNT(*) FROM %s", table)) + r := db.QueryRow(armadacontext.Background(), fmt.Sprintf("SELECT COUNT(*) FROM %s", table)) err := r.Scan(&count) require.NoError(t, err) assert.Equal(t, 0, count) diff --git a/internal/lookoutingesterv2/benchmark/benchmark.go b/internal/lookoutingesterv2/benchmark/benchmark.go index 6c808ca14f3..953ccb85483 100644 --- a/internal/lookoutingesterv2/benchmark/benchmark.go +++ b/internal/lookoutingesterv2/benchmark/benchmark.go @@ -1,7 +1,6 @@ package benchmark import ( - "context" "fmt" "math" "math/rand" @@ -12,6 +11,7 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "k8s.io/utils/pointer" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/lookoutingesterv2/configuration" @@ -51,7 +51,7 @@ func benchmarkSubmissions1000(b *testing.B, config configuration.LookoutIngester withDbBenchmark(b, config, func(b *testing.B, db *pgxpool.Pool) { ldb := lookoutdb.NewLookoutDb(db, metrics.Get(), 2, 10) b.StartTimer() - err := ldb.Store(context.TODO(), instructions) + err := ldb.Store(armadacontext.TODO(), instructions) if err != nil { panic(err) } @@ -69,7 +69,7 @@ func benchmarkSubmissions10000(b *testing.B, config configuration.LookoutIngeste withDbBenchmark(b, config, func(b *testing.B, db *pgxpool.Pool) { ldb := lookoutdb.NewLookoutDb(db, metrics.Get(), 2, 10) b.StartTimer() - err := ldb.Store(context.TODO(), instructions) + err := ldb.Store(armadacontext.TODO(), instructions) if err != nil { panic(err) } @@ -99,12 +99,12 @@ func benchmarkUpdates1000(b *testing.B, config configuration.LookoutIngesterV2Co withDbBenchmark(b, config, func(b *testing.B, db *pgxpool.Pool) { ldb := lookoutdb.NewLookoutDb(db, metrics.Get(), 2, 10) - err := ldb.Store(context.TODO(), initialInstructions) + err := ldb.Store(armadacontext.TODO(), initialInstructions) if err != nil { panic(err) } b.StartTimer() - err = ldb.Store(context.TODO(), instructions) + err = ldb.Store(armadacontext.TODO(), instructions) if err != nil { panic(err) } @@ -134,12 +134,12 @@ func benchmarkUpdates10000(b *testing.B, config configuration.LookoutIngesterV2C withDbBenchmark(b, config, func(b *testing.B, db *pgxpool.Pool) { ldb := lookoutdb.NewLookoutDb(db, metrics.Get(), 2, 10) - err := ldb.Store(context.TODO(), initialInstructions) + err := ldb.Store(armadacontext.TODO(), initialInstructions) if err != nil { panic(err) } b.StartTimer() - err = ldb.Store(context.TODO(), instructions) + err = ldb.Store(armadacontext.TODO(), instructions) if err != nil { panic(err) } diff --git a/internal/lookoutingesterv2/instructions/instructions.go b/internal/lookoutingesterv2/instructions/instructions.go index 49e519f5eb4..25decf00f2a 100644 --- a/internal/lookoutingesterv2/instructions/instructions.go +++ b/internal/lookoutingesterv2/instructions/instructions.go @@ 
-1,7 +1,6 @@ package instructions import ( - "context" "fmt" "sort" "strings" @@ -14,6 +13,7 @@ import ( v1 "k8s.io/api/core/v1" "k8s.io/utils/pointer" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/database/lookout" "github.com/armadaproject/armada/internal/common/eventutil" @@ -65,7 +65,7 @@ func (c *InstructionConverter) IsLegacy() bool { return c.useLegacyEventConversion } -func (c *InstructionConverter) Convert(ctx context.Context, sequencesWithIds *ingest.EventSequencesWithIds) *model.InstructionSet { +func (c *InstructionConverter) Convert(ctx *armadacontext.Context, sequencesWithIds *ingest.EventSequencesWithIds) *model.InstructionSet { updateInstructions := &model.InstructionSet{ MessageIds: sequencesWithIds.MessageIds, } @@ -77,7 +77,7 @@ func (c *InstructionConverter) Convert(ctx context.Context, sequencesWithIds *in } func (c *InstructionConverter) convertSequence( - ctx context.Context, + ctx *armadacontext.Context, sequence *armadaevents.EventSequence, update *model.InstructionSet, ) { diff --git a/internal/lookoutingesterv2/instructions/instructions_test.go b/internal/lookoutingesterv2/instructions/instructions_test.go index d70d7d3900d..36e58983283 100644 --- a/internal/lookoutingesterv2/instructions/instructions_test.go +++ b/internal/lookoutingesterv2/instructions/instructions_test.go @@ -1,7 +1,6 @@ package instructions import ( - "context" "fmt" "strings" "testing" @@ -14,6 +13,7 @@ import ( "k8s.io/apimachinery/pkg/api/resource" "k8s.io/utils/pointer" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/database/lookout" "github.com/armadaproject/armada/internal/common/eventutil" @@ -560,7 +560,7 @@ func TestConvert(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { converter := NewInstructionConverter(metrics.Get(), userAnnotationPrefix, &compress.NoOpCompressor{}, tc.useLegacyEventConversion) - instructionSet := converter.Convert(context.TODO(), tc.events) + instructionSet := converter.Convert(armadacontext.TODO(), tc.events) assert.Equal(t, tc.expected.JobsToCreate, instructionSet.JobsToCreate) assert.Equal(t, tc.expected.JobsToUpdate, instructionSet.JobsToUpdate) assert.Equal(t, tc.expected.JobRunsToCreate, instructionSet.JobRunsToCreate) @@ -571,7 +571,7 @@ func TestConvert(t *testing.T) { func TestFailedWithMissingRunId(t *testing.T) { converter := NewInstructionConverter(metrics.Get(), userAnnotationPrefix, &compress.NoOpCompressor{}, true) - instructions := converter.Convert(context.Background(), &ingest.EventSequencesWithIds{ + instructions := converter.Convert(armadacontext.Background(), &ingest.EventSequencesWithIds{ EventSequences: []*armadaevents.EventSequence{testfixtures.NewEventSequence(testfixtures.JobLeaseReturned)}, MessageIds: []pulsar.MessageID{pulsarutils.NewMessageId(1)}, }) @@ -631,7 +631,7 @@ func TestTruncatesStringsThatAreTooLong(t *testing.T) { } converter := NewInstructionConverter(metrics.Get(), userAnnotationPrefix, &compress.NoOpCompressor{}, true) - actual := converter.Convert(context.TODO(), events) + actual := converter.Convert(armadacontext.TODO(), events) // String lengths obtained from database schema assert.Len(t, actual.JobsToCreate[0].Queue, 512) diff --git a/internal/lookoutingesterv2/lookoutdb/insertion.go 
b/internal/lookoutingesterv2/lookoutdb/insertion.go index c5378543df0..2e13c453213 100644 --- a/internal/lookoutingesterv2/lookoutdb/insertion.go +++ b/internal/lookoutingesterv2/lookoutdb/insertion.go @@ -1,7 +1,6 @@ package lookoutdb import ( - "context" "fmt" "sync" "time" @@ -11,6 +10,7 @@ import ( "github.com/pkg/errors" log "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/armadaerrors" "github.com/armadaproject/armada/internal/common/database" "github.com/armadaproject/armada/internal/common/database/lookout" @@ -36,7 +36,7 @@ func NewLookoutDb(db *pgxpool.Pool, metrics *metrics.Metrics, maxAttempts int, m // * Job Run Updates // In each case we first try to bach insert the rows using the postgres copy protocol. If this fails then we try a // slower, serial insert and discard any rows that cannot be inserted. -func (l *LookoutDb) Store(ctx context.Context, instructions *model.InstructionSet) error { +func (l *LookoutDb) Store(ctx *armadacontext.Context, instructions *model.InstructionSet) error { // We might have multiple updates for the same job or job run // These can be conflated to help performance jobsToUpdate := conflateJobUpdates(instructions.JobsToUpdate) @@ -68,7 +68,7 @@ func (l *LookoutDb) Store(ctx context.Context, instructions *model.InstructionSe return nil } -func (l *LookoutDb) CreateJobs(ctx context.Context, instructions []*model.CreateJobInstruction) { +func (l *LookoutDb) CreateJobs(ctx *armadacontext.Context, instructions []*model.CreateJobInstruction) { if len(instructions) == 0 { return } @@ -79,7 +79,7 @@ func (l *LookoutDb) CreateJobs(ctx context.Context, instructions []*model.Create } } -func (l *LookoutDb) UpdateJobs(ctx context.Context, instructions []*model.UpdateJobInstruction) { +func (l *LookoutDb) UpdateJobs(ctx *armadacontext.Context, instructions []*model.UpdateJobInstruction) { if len(instructions) == 0 { return } @@ -91,7 +91,7 @@ func (l *LookoutDb) UpdateJobs(ctx context.Context, instructions []*model.Update } } -func (l *LookoutDb) CreateJobRuns(ctx context.Context, instructions []*model.CreateJobRunInstruction) { +func (l *LookoutDb) CreateJobRuns(ctx *armadacontext.Context, instructions []*model.CreateJobRunInstruction) { if len(instructions) == 0 { return } @@ -102,7 +102,7 @@ func (l *LookoutDb) CreateJobRuns(ctx context.Context, instructions []*model.Cre } } -func (l *LookoutDb) UpdateJobRuns(ctx context.Context, instructions []*model.UpdateJobRunInstruction) { +func (l *LookoutDb) UpdateJobRuns(ctx *armadacontext.Context, instructions []*model.UpdateJobRunInstruction) { if len(instructions) == 0 { return } @@ -113,7 +113,7 @@ func (l *LookoutDb) UpdateJobRuns(ctx context.Context, instructions []*model.Upd } } -func (l *LookoutDb) CreateUserAnnotations(ctx context.Context, instructions []*model.CreateUserAnnotationInstruction) { +func (l *LookoutDb) CreateUserAnnotations(ctx *armadacontext.Context, instructions []*model.CreateUserAnnotationInstruction) { if len(instructions) == 0 { return } @@ -124,7 +124,7 @@ func (l *LookoutDb) CreateUserAnnotations(ctx context.Context, instructions []*m } } -func (l *LookoutDb) CreateJobsBatch(ctx context.Context, instructions []*model.CreateJobInstruction) error { +func (l *LookoutDb) CreateJobsBatch(ctx *armadacontext.Context, instructions []*model.CreateJobInstruction) error { return l.withDatabaseRetryInsert(func() error { tmpTable := database.UniqueTableName("job") @@ -231,7 +231,7 @@ func (l *LookoutDb) 
CreateJobsBatch(ctx context.Context, instructions []*model.C } // CreateJobsScalar will insert jobs one by one into the database -func (l *LookoutDb) CreateJobsScalar(ctx context.Context, instructions []*model.CreateJobInstruction) { +func (l *LookoutDb) CreateJobsScalar(ctx *armadacontext.Context, instructions []*model.CreateJobInstruction) { sqlStatement := `INSERT INTO job ( job_id, queue, @@ -279,7 +279,7 @@ func (l *LookoutDb) CreateJobsScalar(ctx context.Context, instructions []*model. } } -func (l *LookoutDb) UpdateJobsBatch(ctx context.Context, instructions []*model.UpdateJobInstruction) error { +func (l *LookoutDb) UpdateJobsBatch(ctx *armadacontext.Context, instructions []*model.UpdateJobInstruction) error { return l.withDatabaseRetryInsert(func() error { tmpTable := database.UniqueTableName("job") @@ -358,7 +358,7 @@ func (l *LookoutDb) UpdateJobsBatch(ctx context.Context, instructions []*model.U }) } -func (l *LookoutDb) UpdateJobsScalar(ctx context.Context, instructions []*model.UpdateJobInstruction) { +func (l *LookoutDb) UpdateJobsScalar(ctx *armadacontext.Context, instructions []*model.UpdateJobInstruction) { sqlStatement := `UPDATE job SET priority = coalesce($2, priority), @@ -393,7 +393,7 @@ func (l *LookoutDb) UpdateJobsScalar(ctx context.Context, instructions []*model. } } -func (l *LookoutDb) CreateJobRunsBatch(ctx context.Context, instructions []*model.CreateJobRunInstruction) error { +func (l *LookoutDb) CreateJobRunsBatch(ctx *armadacontext.Context, instructions []*model.CreateJobRunInstruction) error { return l.withDatabaseRetryInsert(func() error { tmpTable := database.UniqueTableName("job_run") @@ -464,7 +464,7 @@ func (l *LookoutDb) CreateJobRunsBatch(ctx context.Context, instructions []*mode }) } -func (l *LookoutDb) CreateJobRunsScalar(ctx context.Context, instructions []*model.CreateJobRunInstruction) { +func (l *LookoutDb) CreateJobRunsScalar(ctx *armadacontext.Context, instructions []*model.CreateJobRunInstruction) { sqlStatement := `INSERT INTO job_run ( run_id, job_id, @@ -496,7 +496,7 @@ func (l *LookoutDb) CreateJobRunsScalar(ctx context.Context, instructions []*mod } } -func (l *LookoutDb) UpdateJobRunsBatch(ctx context.Context, instructions []*model.UpdateJobRunInstruction) error { +func (l *LookoutDb) UpdateJobRunsBatch(ctx *armadacontext.Context, instructions []*model.UpdateJobRunInstruction) error { return l.withDatabaseRetryInsert(func() error { tmpTable := database.UniqueTableName("job_run") @@ -571,7 +571,7 @@ func (l *LookoutDb) UpdateJobRunsBatch(ctx context.Context, instructions []*mode }) } -func (l *LookoutDb) UpdateJobRunsScalar(ctx context.Context, instructions []*model.UpdateJobRunInstruction) { +func (l *LookoutDb) UpdateJobRunsScalar(ctx *armadacontext.Context, instructions []*model.UpdateJobRunInstruction) { sqlStatement := `UPDATE job_run SET node = coalesce($2, node), @@ -604,7 +604,7 @@ func (l *LookoutDb) UpdateJobRunsScalar(ctx context.Context, instructions []*mod } } -func (l *LookoutDb) CreateUserAnnotationsBatch(ctx context.Context, instructions []*model.CreateUserAnnotationInstruction) error { +func (l *LookoutDb) CreateUserAnnotationsBatch(ctx *armadacontext.Context, instructions []*model.CreateUserAnnotationInstruction) error { return l.withDatabaseRetryInsert(func() error { tmpTable := database.UniqueTableName("user_annotation_lookup") @@ -667,7 +667,7 @@ func (l *LookoutDb) CreateUserAnnotationsBatch(ctx context.Context, instructions }) } -func (l *LookoutDb) CreateUserAnnotationsScalar(ctx context.Context, instructions 
[]*model.CreateUserAnnotationInstruction) { +func (l *LookoutDb) CreateUserAnnotationsScalar(ctx *armadacontext.Context, instructions []*model.CreateUserAnnotationInstruction) { sqlStatement := `INSERT INTO user_annotation_lookup ( job_id, key, @@ -696,7 +696,7 @@ func (l *LookoutDb) CreateUserAnnotationsScalar(ctx context.Context, instruction } } -func batchInsert(ctx context.Context, db *pgxpool.Pool, createTmp func(pgx.Tx) error, +func batchInsert(ctx *armadacontext.Context, db *pgxpool.Pool, createTmp func(pgx.Tx) error, insertTmp func(pgx.Tx) error, copyToDest func(pgx.Tx) error, ) error { return pgx.BeginTxFunc(ctx, db, pgx.TxOptions{ @@ -834,7 +834,7 @@ type updateInstructionsForJob struct { // in the terminal state. If, however, the database returns a non-retryable error it will give up and simply not // filter out any events as the job state is undetermined. func (l *LookoutDb) filterEventsForTerminalJobs( - ctx context.Context, + ctx *armadacontext.Context, db *pgxpool.Pool, instructions []*model.UpdateJobInstruction, m *metrics.Metrics, diff --git a/internal/lookoutingesterv2/lookoutdb/insertion_test.go b/internal/lookoutingesterv2/lookoutdb/insertion_test.go index 9de584df3fa..13b64c12365 100644 --- a/internal/lookoutingesterv2/lookoutdb/insertion_test.go +++ b/internal/lookoutingesterv2/lookoutdb/insertion_test.go @@ -1,7 +1,6 @@ package lookoutdb import ( - ctx "context" "fmt" "sort" "testing" @@ -12,6 +11,7 @@ import ( "github.com/stretchr/testify/assert" "k8s.io/utils/pointer" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database/lookout" "github.com/armadaproject/armada/internal/common/pulsarutils" "github.com/armadaproject/armada/internal/lookoutingesterv2/metrics" @@ -202,24 +202,24 @@ func TestCreateJobsBatch(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { ldb := NewLookoutDb(db, m, 2, 10) // Insert - err := ldb.CreateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) assert.Nil(t, err) job := getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterSubmit, job) // Insert again and test that it's idempotent - err = ldb.CreateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToCreate) + err = ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) assert.Nil(t, err) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterSubmit, job) // If a row is bad then we should return an error and no updates should happen - _, err = db.Exec(ctx.Background(), "DELETE FROM job") + _, err = db.Exec(armadacontext.Background(), "DELETE FROM job") assert.NoError(t, err) invalidJob := &model.CreateJobInstruction{ JobId: invalidId, } - err = ldb.CreateJobsBatch(ctx.Background(), append(defaultInstructionSet().JobsToCreate, invalidJob)) + err = ldb.CreateJobsBatch(armadacontext.Background(), append(defaultInstructionSet().JobsToCreate, invalidJob)) assert.Error(t, err) assertNoRows(t, db, "job") return nil @@ -231,29 +231,29 @@ func TestUpdateJobsBatch(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { ldb := NewLookoutDb(db, m, 2, 10) // Insert - err := ldb.CreateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) assert.Nil(t, err) // Update - err = ldb.UpdateJobsBatch(ctx.Background(), 
defaultInstructionSet().JobsToUpdate) + err = ldb.UpdateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToUpdate) assert.Nil(t, err) job := getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterUpdate, job) - err = ldb.UpdateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToUpdate) + err = ldb.UpdateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToUpdate) assert.Nil(t, err) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterUpdate, job) // If an update is bad then we should return an error and no updates should happen - _, err = db.Exec(ctx.Background(), "DELETE FROM job") + _, err = db.Exec(armadacontext.Background(), "DELETE FROM job") assert.NoError(t, err) - err = ldb.CreateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToCreate) + err = ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) assert.Nil(t, err) invalidUpdate := &model.UpdateJobInstruction{ JobId: invalidId, } - err = ldb.UpdateJobsBatch(ctx.Background(), append(defaultInstructionSet().JobsToUpdate, invalidUpdate)) + err = ldb.UpdateJobsBatch(armadacontext.Background(), append(defaultInstructionSet().JobsToUpdate, invalidUpdate)) assert.Error(t, err) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterSubmit, job) @@ -266,28 +266,28 @@ func TestUpdateJobsScalar(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { ldb := NewLookoutDb(db, m, 2, 10) // Insert - err := ldb.CreateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) assert.Nil(t, err) // Update - ldb.UpdateJobsScalar(ctx.Background(), defaultInstructionSet().JobsToUpdate) + ldb.UpdateJobsScalar(armadacontext.Background(), defaultInstructionSet().JobsToUpdate) job := getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterUpdate, job) // Insert again and test that it's idempotent - ldb.UpdateJobsScalar(ctx.Background(), defaultInstructionSet().JobsToUpdate) + ldb.UpdateJobsScalar(armadacontext.Background(), defaultInstructionSet().JobsToUpdate) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterUpdate, job) // If a update is bad then we should return an error and no updates should happen - _, err = db.Exec(ctx.Background(), "DELETE FROM job") + _, err = db.Exec(armadacontext.Background(), "DELETE FROM job") assert.NoError(t, err) - err = ldb.CreateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToCreate) + err = ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) assert.Nil(t, err) invalidUpdate := &model.UpdateJobInstruction{ JobId: invalidId, } - ldb.UpdateJobsScalar(ctx.Background(), append(defaultInstructionSet().JobsToUpdate, invalidUpdate)) + ldb.UpdateJobsScalar(armadacontext.Background(), append(defaultInstructionSet().JobsToUpdate, invalidUpdate)) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterUpdate, job) return nil @@ -399,13 +399,13 @@ func TestUpdateJobsWithTerminal(t *testing.T) { ldb := NewLookoutDb(db, m, 2, 10) // Insert - ldb.CreateJobs(ctx.Background(), initial) + ldb.CreateJobs(armadacontext.Background(), initial) // Mark the jobs terminal - ldb.UpdateJobs(ctx.Background(), update1) + ldb.UpdateJobs(armadacontext.Background(), update1) // Update the jobs - these should be discarded - ldb.UpdateJobs(ctx.Background(), update2) + ldb.UpdateJobs(armadacontext.Background(), update2) // Assert the states are still terminal job 
:= getJob(t, db, jobIdString) @@ -427,22 +427,22 @@ func TestCreateJobsScalar(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { ldb := NewLookoutDb(db, m, 2, 10) // Simple create - ldb.CreateJobsScalar(ctx.Background(), defaultInstructionSet().JobsToCreate) + ldb.CreateJobsScalar(armadacontext.Background(), defaultInstructionSet().JobsToCreate) job := getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterSubmit, job) // Insert again and check for idempotency - ldb.CreateJobsScalar(ctx.Background(), defaultInstructionSet().JobsToCreate) + ldb.CreateJobsScalar(armadacontext.Background(), defaultInstructionSet().JobsToCreate) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterSubmit, job) // If a row is bad then we should update only the good rows - _, err := ldb.db.Exec(ctx.Background(), "DELETE FROM job") + _, err := ldb.db.Exec(armadacontext.Background(), "DELETE FROM job") assert.NoError(t, err) invalidJob := &model.CreateJobInstruction{ JobId: invalidId, } - ldb.CreateJobsScalar(ctx.Background(), append(defaultInstructionSet().JobsToCreate, invalidJob)) + ldb.CreateJobsScalar(armadacontext.Background(), append(defaultInstructionSet().JobsToCreate, invalidJob)) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterSubmit, job) return nil @@ -454,28 +454,28 @@ func TestCreateJobRunsBatch(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { ldb := NewLookoutDb(db, m, 2, 10) // Need to make sure we have a job, so we can satisfy PK - err := ldb.CreateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) assert.Nil(t, err) // Insert - err = ldb.CreateJobRunsBatch(ctx.Background(), defaultInstructionSet().JobRunsToCreate) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) assert.Nil(t, err) job := getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRun, job) // Insert again and test that it's idempotent - err = ldb.CreateJobRunsBatch(ctx.Background(), defaultInstructionSet().JobRunsToCreate) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) assert.Nil(t, err) job = getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRun, job) // If a row is bad then we should return an error and no updates should happen - _, err = ldb.db.Exec(ctx.Background(), "DELETE FROM job_run") + _, err = ldb.db.Exec(armadacontext.Background(), "DELETE FROM job_run") assert.NoError(t, err) invalidRun := &model.CreateJobRunInstruction{ RunId: invalidId, } - err = ldb.CreateJobRunsBatch(ctx.Background(), append(defaultInstructionSet().JobRunsToCreate, invalidRun)) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), append(defaultInstructionSet().JobRunsToCreate, invalidRun)) assert.Error(t, err) assertNoRows(t, db, "job_run") return nil @@ -487,26 +487,26 @@ func TestCreateJobRunsScalar(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { ldb := NewLookoutDb(db, m, 2, 10) // Need to make sure we have a job, so we can satisfy PK - err := ldb.CreateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) assert.Nil(t, err) // Insert - ldb.CreateJobRunsScalar(ctx.Background(), defaultInstructionSet().JobRunsToCreate) + ldb.CreateJobRunsScalar(armadacontext.Background(), 
defaultInstructionSet().JobRunsToCreate) job := getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRun, job) // Insert again and test that it's idempotent - ldb.CreateJobRunsScalar(ctx.Background(), defaultInstructionSet().JobRunsToCreate) + ldb.CreateJobRunsScalar(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) job = getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRun, job) // If a row is bad then we create rows that can be created - _, err = db.Exec(ctx.Background(), "DELETE FROM job_run") + _, err = db.Exec(armadacontext.Background(), "DELETE FROM job_run") assert.NoError(t, err) invalidRun := &model.CreateJobRunInstruction{ RunId: invalidId, } - ldb.CreateJobRunsScalar(ctx.Background(), append(defaultInstructionSet().JobRunsToCreate, invalidRun)) + ldb.CreateJobRunsScalar(armadacontext.Background(), append(defaultInstructionSet().JobRunsToCreate, invalidRun)) job = getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRun, job) return nil @@ -518,33 +518,33 @@ func TestUpdateJobRunsBatch(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { ldb := NewLookoutDb(db, m, 2, 10) // Need to make sure we have a job and run - err := ldb.CreateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) assert.Nil(t, err) - err = ldb.CreateJobRunsBatch(ctx.Background(), defaultInstructionSet().JobRunsToCreate) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) assert.Nil(t, err) // Update - err = ldb.UpdateJobRunsBatch(ctx.Background(), defaultInstructionSet().JobRunsToUpdate) + err = ldb.UpdateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToUpdate) assert.Nil(t, err) run := getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRunAfterUpdate, run) // Update again and test that it's idempotent - err = ldb.UpdateJobRunsBatch(ctx.Background(), defaultInstructionSet().JobRunsToUpdate) + err = ldb.UpdateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToUpdate) assert.Nil(t, err) run = getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRunAfterUpdate, run) // If a row is bad then we should return an error and no updates should happen - _, err = db.Exec(ctx.Background(), "DELETE FROM job_run;") + _, err = db.Exec(armadacontext.Background(), "DELETE FROM job_run;") assert.Nil(t, err) invalidRun := &model.UpdateJobRunInstruction{ RunId: invalidId, } - err = ldb.CreateJobRunsBatch(ctx.Background(), defaultInstructionSet().JobRunsToCreate) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) assert.Nil(t, err) - err = ldb.UpdateJobRunsBatch(ctx.Background(), append(defaultInstructionSet().JobRunsToUpdate, invalidRun)) + err = ldb.UpdateJobRunsBatch(armadacontext.Background(), append(defaultInstructionSet().JobRunsToUpdate, invalidRun)) assert.Error(t, err) run = getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRun, run) @@ -557,33 +557,33 @@ func TestUpdateJobRunsScalar(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { ldb := NewLookoutDb(db, m, 2, 10) // Need to make sure we have a job and run - err := ldb.CreateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) assert.Nil(t, err) - err = ldb.CreateJobRunsBatch(ctx.Background(), 
defaultInstructionSet().JobRunsToCreate) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) assert.Nil(t, err) // Update - ldb.UpdateJobRunsScalar(ctx.Background(), defaultInstructionSet().JobRunsToUpdate) + ldb.UpdateJobRunsScalar(armadacontext.Background(), defaultInstructionSet().JobRunsToUpdate) assert.Nil(t, err) run := getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRunAfterUpdate, run) // Update again and test that it's idempotent - ldb.UpdateJobRunsScalar(ctx.Background(), defaultInstructionSet().JobRunsToUpdate) + ldb.UpdateJobRunsScalar(armadacontext.Background(), defaultInstructionSet().JobRunsToUpdate) assert.Nil(t, err) run = getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRunAfterUpdate, run) // If a row is bad then we should update the rows we can - _, err = ldb.db.Exec(ctx.Background(), "DELETE FROM job_run;") + _, err = ldb.db.Exec(armadacontext.Background(), "DELETE FROM job_run;") assert.Nil(t, err) invalidRun := &model.UpdateJobRunInstruction{ RunId: invalidId, } - err = ldb.CreateJobRunsBatch(ctx.Background(), defaultInstructionSet().JobRunsToCreate) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) assert.Nil(t, err) - ldb.UpdateJobRunsScalar(ctx.Background(), append(defaultInstructionSet().JobRunsToUpdate, invalidRun)) + ldb.UpdateJobRunsScalar(armadacontext.Background(), append(defaultInstructionSet().JobRunsToUpdate, invalidRun)) run = getJobRun(t, ldb.db, runIdString) assert.Equal(t, expectedJobRunAfterUpdate, run) return nil @@ -595,28 +595,28 @@ func TestCreateUserAnnotationsBatch(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { ldb := NewLookoutDb(db, m, 2, 10) // Need to make sure we have a job - err := ldb.CreateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) assert.Nil(t, err) // Insert - err = ldb.CreateUserAnnotationsBatch(ctx.Background(), defaultInstructionSet().UserAnnotationsToCreate) + err = ldb.CreateUserAnnotationsBatch(armadacontext.Background(), defaultInstructionSet().UserAnnotationsToCreate) assert.Nil(t, err) annotation := getUserAnnotationLookup(t, db, jobIdString) assert.Equal(t, expectedUserAnnotation, annotation) // Insert again and test that it's idempotent - err = ldb.CreateUserAnnotationsBatch(ctx.Background(), defaultInstructionSet().UserAnnotationsToCreate) + err = ldb.CreateUserAnnotationsBatch(armadacontext.Background(), defaultInstructionSet().UserAnnotationsToCreate) assert.Nil(t, err) annotation = getUserAnnotationLookup(t, db, jobIdString) assert.Equal(t, expectedUserAnnotation, annotation) // If a row is bad then we should return an error and no updates should happen - _, err = ldb.db.Exec(ctx.Background(), "DELETE FROM user_annotation_lookup") + _, err = ldb.db.Exec(armadacontext.Background(), "DELETE FROM user_annotation_lookup") assert.NoError(t, err) invalidAnnotation := &model.CreateUserAnnotationInstruction{ JobId: invalidId, } - err = ldb.CreateUserAnnotationsBatch(ctx.Background(), append(defaultInstructionSet().UserAnnotationsToCreate, invalidAnnotation)) + err = ldb.CreateUserAnnotationsBatch(armadacontext.Background(), append(defaultInstructionSet().UserAnnotationsToCreate, invalidAnnotation)) assert.Error(t, err) assertNoRows(t, ldb.db, "user_annotation_lookup") return nil @@ -627,7 +627,7 @@ func TestCreateUserAnnotationsBatch(t *testing.T) { func 
TestStoreWithEmptyInstructionSet(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { ldb := NewLookoutDb(db, m, 2, 10) - err := ldb.Store(ctx.Background(), &model.InstructionSet{ + err := ldb.Store(armadacontext.Background(), &model.InstructionSet{ MessageIds: []pulsar.MessageID{pulsarutils.NewMessageId(1)}, }) assert.NoError(t, err) @@ -643,26 +643,26 @@ func TestCreateUserAnnotationsScalar(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { ldb := NewLookoutDb(db, m, 2, 10) // Need to make sure we have a job - err := ldb.CreateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) assert.Nil(t, err) // Insert - ldb.CreateUserAnnotationsScalar(ctx.Background(), defaultInstructionSet().UserAnnotationsToCreate) + ldb.CreateUserAnnotationsScalar(armadacontext.Background(), defaultInstructionSet().UserAnnotationsToCreate) annotation := getUserAnnotationLookup(t, db, jobIdString) assert.Equal(t, expectedUserAnnotation, annotation) // Insert again and test that it's idempotent - ldb.CreateUserAnnotationsScalar(ctx.Background(), defaultInstructionSet().UserAnnotationsToCreate) + ldb.CreateUserAnnotationsScalar(armadacontext.Background(), defaultInstructionSet().UserAnnotationsToCreate) annotation = getUserAnnotationLookup(t, db, jobIdString) assert.Equal(t, expectedUserAnnotation, annotation) // If a row is bad then we should update the rows we can - _, err = ldb.db.Exec(ctx.Background(), "DELETE FROM user_annotation_lookup") + _, err = ldb.db.Exec(armadacontext.Background(), "DELETE FROM user_annotation_lookup") assert.NoError(t, err) invalidAnnotation := &model.CreateUserAnnotationInstruction{ JobId: invalidId, } - ldb.CreateUserAnnotationsScalar(ctx.Background(), append(defaultInstructionSet().UserAnnotationsToCreate, invalidAnnotation)) + ldb.CreateUserAnnotationsScalar(armadacontext.Background(), append(defaultInstructionSet().UserAnnotationsToCreate, invalidAnnotation)) annotation = getUserAnnotationLookup(t, ldb.db, jobIdString) assert.Equal(t, expectedUserAnnotation, annotation) return nil @@ -674,7 +674,7 @@ func TestStore(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { ldb := NewLookoutDb(db, m, 2, 10) // Do the update - err := ldb.Store(ctx.Background(), defaultInstructionSet()) + err := ldb.Store(armadacontext.Background(), defaultInstructionSet()) assert.NoError(t, err) job := getJob(t, ldb.db, jobIdString) @@ -843,7 +843,7 @@ func TestStoreNullValue(t *testing.T) { ldb := NewLookoutDb(db, m, 2, 10) // Do the update - err := ldb.Store(ctx.Background(), instructions) + err := ldb.Store(armadacontext.Background(), instructions) assert.NoError(t, err) job := getJob(t, ldb.db, jobIdString) @@ -875,7 +875,7 @@ func TestStoreEventsForAlreadyTerminalJobs(t *testing.T) { } // Create the jobs in the DB - err := ldb.Store(ctx.Background(), baseInstructions) + err := ldb.Store(armadacontext.Background(), baseInstructions) assert.NoError(t, err) mutateInstructions := &model.InstructionSet{ @@ -895,7 +895,7 @@ func TestStoreEventsForAlreadyTerminalJobs(t *testing.T) { } // Update the jobs in the DB - err = ldb.Store(ctx.Background(), mutateInstructions) + err = ldb.Store(armadacontext.Background(), mutateInstructions) assert.NoError(t, err) for _, jobId := range []string{"job-1", "job-2", "job-3"} { @@ -941,7 +941,7 @@ func makeUpdateJobInstruction(jobId string, state int32) *model.UpdateJobInstruc func getJob(t 
*testing.T, db *pgxpool.Pool, jobId string) JobRow { job := JobRow{} r := db.QueryRow( - ctx.Background(), + armadacontext.Background(), `SELECT job_id, queue, @@ -992,7 +992,7 @@ func getJob(t *testing.T, db *pgxpool.Pool, jobId string) JobRow { func getJobRun(t *testing.T, db *pgxpool.Pool, runId string) JobRunRow { run := JobRunRow{} r := db.QueryRow( - ctx.Background(), + armadacontext.Background(), `SELECT run_id, job_id, @@ -1025,7 +1025,7 @@ func getJobRun(t *testing.T, db *pgxpool.Pool, runId string) JobRunRow { func getUserAnnotationLookup(t *testing.T, db *pgxpool.Pool, jobId string) UserAnnotationRow { annotation := UserAnnotationRow{} r := db.QueryRow( - ctx.Background(), + armadacontext.Background(), `SELECT job_id, key, value, queue, jobset FROM user_annotation_lookup WHERE job_id = $1`, jobId) err := r.Scan(&annotation.JobId, &annotation.Key, &annotation.Value, &annotation.Queue, &annotation.JobSet) @@ -1037,7 +1037,7 @@ func assertNoRows(t *testing.T, db *pgxpool.Pool, table string) { t.Helper() var count int query := fmt.Sprintf("SELECT COUNT(*) FROM %s", table) - r := db.QueryRow(ctx.Background(), query) + r := db.QueryRow(armadacontext.Background(), query) err := r.Scan(&count) assert.NoError(t, err) assert.Equal(t, 0, count) diff --git a/internal/lookoutv2/application.go b/internal/lookoutv2/application.go index ca6844f8b32..0b0fa42bb86 100644 --- a/internal/lookoutv2/application.go +++ b/internal/lookoutv2/application.go @@ -3,10 +3,12 @@ package lookoutv2 import ( + "github.com/caarlos0/log" "github.com/go-openapi/loads" "github.com/go-openapi/runtime/middleware" - log "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/database" "github.com/armadaproject/armada/internal/common/util" @@ -38,6 +40,8 @@ func Serve(configuration configuration.LookoutV2Configuration) error { // create new service API api := operations.NewLookoutAPI(swaggerSpec) + logger := logrus.NewEntry(logrus.New()) + api.GetHealthHandler = operations.GetHealthHandlerFunc( func(params operations.GetHealthParams) middleware.Responder { return operations.NewGetHealthOK().WithPayload("Health check passed") @@ -53,7 +57,7 @@ func Serve(configuration configuration.LookoutV2Configuration) error { skip = int(*params.GetJobsRequest.Skip) } result, err := getJobsRepo.GetJobs( - params.HTTPRequest.Context(), + armadacontext.New(params.HTTPRequest.Context(), logger), filters, params.GetJobsRequest.ActiveJobSets, order, @@ -78,7 +82,7 @@ func Serve(configuration configuration.LookoutV2Configuration) error { skip = int(*params.GroupJobsRequest.Skip) } result, err := groupJobsRepo.GroupBy( - params.HTTPRequest.Context(), + armadacontext.New(params.HTTPRequest.Context(), logger), filters, params.GroupJobsRequest.ActiveJobSets, order, @@ -98,7 +102,8 @@ func Serve(configuration configuration.LookoutV2Configuration) error { api.GetJobRunErrorHandler = operations.GetJobRunErrorHandlerFunc( func(params operations.GetJobRunErrorParams) middleware.Responder { - result, err := getJobRunErrorRepo.GetJobRunError(params.HTTPRequest.Context(), params.GetJobRunErrorRequest.RunID) + ctx := armadacontext.New(params.HTTPRequest.Context(), logger) + result, err := getJobRunErrorRepo.GetJobRunError(ctx, params.GetJobRunErrorRequest.RunID) if err != nil { return operations.NewGetJobRunErrorBadRequest().WithPayload(conversions.ToSwaggerError(err.Error())) } 
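Note on the pattern in the application.go hunks above: each handler now wraps the incoming request context with armadacontext.New(params.HTTPRequest.Context(), logger) before calling into the repositories, and elsewhere in this patch a *armadacontext.Context is passed directly to pgx calls that expect a plain context.Context. The internals of internal/common/armadacontext are not shown in this patch, so the following is only a minimal sketch of what such a wrapper might look like; the Log field name and the default logger in Background are assumptions, not the real Armada implementation.

package armadacontext // hypothetical sketch, not the actual internal/common/armadacontext package

import (
	"context"

	"github.com/sirupsen/logrus"
)

// Context embeds a standard context.Context, so it still satisfies that
// interface and can be handed straight to libraries such as pgx, while also
// carrying a request-scoped logger.
type Context struct {
	context.Context
	Log *logrus.Entry // assumed field name
}

// New wraps an existing context.Context together with a logger.
func New(ctx context.Context, log *logrus.Entry) *Context {
	return &Context{Context: ctx, Log: log}
}

// Background mirrors context.Background but returns the wrapped type,
// shown here with a default logger purely for illustration.
func Background() *Context {
	return New(context.Background(), logrus.NewEntry(logrus.New()))
}

Because the wrapper still implements context.Context, call sites such as db.QueryRow(ctx, ...) keep compiling unchanged; the addition is that repositories can log through ctx.Log without a separately threaded logger.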
@@ -110,7 +115,8 @@ func Serve(configuration configuration.LookoutV2Configuration) error { api.GetJobSpecHandler = operations.GetJobSpecHandlerFunc( func(params operations.GetJobSpecParams) middleware.Responder { - result, err := getJobSpecRepo.GetJobSpec(params.HTTPRequest.Context(), params.GetJobSpecRequest.JobID) + ctx := armadacontext.New(params.HTTPRequest.Context(), logger) + result, err := getJobSpecRepo.GetJobSpec(ctx, params.GetJobSpecRequest.JobID) if err != nil { return operations.NewGetJobSpecBadRequest().WithPayload(conversions.ToSwaggerError(err.Error())) } diff --git a/internal/lookoutv2/gen/restapi/doc.go b/internal/lookoutv2/gen/restapi/doc.go index 23beb22a1a0..a8686cb04ea 100644 --- a/internal/lookoutv2/gen/restapi/doc.go +++ b/internal/lookoutv2/gen/restapi/doc.go @@ -2,18 +2,18 @@ // Package restapi Lookout v2 API // -// Schemes: -// http -// Host: localhost -// BasePath: / -// Version: 2.0.0 +// Schemes: +// http +// Host: localhost +// BasePath: / +// Version: 2.0.0 // -// Consumes: -// - application/json +// Consumes: +// - application/json // -// Produces: -// - application/json -// - text/plain +// Produces: +// - application/json +// - text/plain // // swagger:meta package restapi diff --git a/internal/lookoutv2/gen/restapi/operations/get_health.go b/internal/lookoutv2/gen/restapi/operations/get_health.go index 16cd6803823..d7c8a7dc5ad 100644 --- a/internal/lookoutv2/gen/restapi/operations/get_health.go +++ b/internal/lookoutv2/gen/restapi/operations/get_health.go @@ -29,10 +29,10 @@ func NewGetHealth(ctx *middleware.Context, handler GetHealthHandler) *GetHealth return &GetHealth{Context: ctx, Handler: handler} } -/* GetHealth swagger:route GET /health getHealth +/* + GetHealth swagger:route GET /health getHealth GetHealth get health API - */ type GetHealth struct { Context *middleware.Context diff --git a/internal/lookoutv2/gen/restapi/operations/get_health_responses.go b/internal/lookoutv2/gen/restapi/operations/get_health_responses.go index 032b8c2cb0d..c54a26244c4 100644 --- a/internal/lookoutv2/gen/restapi/operations/get_health_responses.go +++ b/internal/lookoutv2/gen/restapi/operations/get_health_responses.go @@ -14,7 +14,8 @@ import ( // GetHealthOKCode is the HTTP code returned for type GetHealthOK const GetHealthOKCode int = 200 -/*GetHealthOK OK +/* +GetHealthOK OK swagger:response getHealthOK */ @@ -56,7 +57,8 @@ func (o *GetHealthOK) WriteResponse(rw http.ResponseWriter, producer runtime.Pro // GetHealthBadRequestCode is the HTTP code returned for type GetHealthBadRequest const GetHealthBadRequestCode int = 400 -/*GetHealthBadRequest Error response +/* +GetHealthBadRequest Error response swagger:response getHealthBadRequest */ diff --git a/internal/lookoutv2/gen/restapi/operations/get_job_run_error.go b/internal/lookoutv2/gen/restapi/operations/get_job_run_error.go index 537d2663379..f8add74ee45 100644 --- a/internal/lookoutv2/gen/restapi/operations/get_job_run_error.go +++ b/internal/lookoutv2/gen/restapi/operations/get_job_run_error.go @@ -34,10 +34,10 @@ func NewGetJobRunError(ctx *middleware.Context, handler GetJobRunErrorHandler) * return &GetJobRunError{Context: ctx, Handler: handler} } -/* GetJobRunError swagger:route POST /api/v1/jobRunError getJobRunError +/* + GetJobRunError swagger:route POST /api/v1/jobRunError getJobRunError GetJobRunError get job run error API - */ type GetJobRunError struct { Context *middleware.Context diff --git a/internal/lookoutv2/gen/restapi/operations/get_job_run_error_responses.go 
b/internal/lookoutv2/gen/restapi/operations/get_job_run_error_responses.go index ff1a82e47c2..e8a17e5b37d 100644 --- a/internal/lookoutv2/gen/restapi/operations/get_job_run_error_responses.go +++ b/internal/lookoutv2/gen/restapi/operations/get_job_run_error_responses.go @@ -16,7 +16,8 @@ import ( // GetJobRunErrorOKCode is the HTTP code returned for type GetJobRunErrorOK const GetJobRunErrorOKCode int = 200 -/*GetJobRunErrorOK Returns error for specific job run (if present) +/* +GetJobRunErrorOK Returns error for specific job run (if present) swagger:response getJobRunErrorOK */ @@ -60,7 +61,8 @@ func (o *GetJobRunErrorOK) WriteResponse(rw http.ResponseWriter, producer runtim // GetJobRunErrorBadRequestCode is the HTTP code returned for type GetJobRunErrorBadRequest const GetJobRunErrorBadRequestCode int = 400 -/*GetJobRunErrorBadRequest Error response +/* +GetJobRunErrorBadRequest Error response swagger:response getJobRunErrorBadRequest */ @@ -101,7 +103,8 @@ func (o *GetJobRunErrorBadRequest) WriteResponse(rw http.ResponseWriter, produce } } -/*GetJobRunErrorDefault Error response +/* +GetJobRunErrorDefault Error response swagger:response getJobRunErrorDefault */ diff --git a/internal/lookoutv2/gen/restapi/operations/get_job_spec.go b/internal/lookoutv2/gen/restapi/operations/get_job_spec.go index a0ee4726d38..74055af08f8 100644 --- a/internal/lookoutv2/gen/restapi/operations/get_job_spec.go +++ b/internal/lookoutv2/gen/restapi/operations/get_job_spec.go @@ -34,10 +34,10 @@ func NewGetJobSpec(ctx *middleware.Context, handler GetJobSpecHandler) *GetJobSp return &GetJobSpec{Context: ctx, Handler: handler} } -/* GetJobSpec swagger:route POST /api/v1/jobSpec getJobSpec +/* + GetJobSpec swagger:route POST /api/v1/jobSpec getJobSpec GetJobSpec get job spec API - */ type GetJobSpec struct { Context *middleware.Context diff --git a/internal/lookoutv2/gen/restapi/operations/get_job_spec_responses.go b/internal/lookoutv2/gen/restapi/operations/get_job_spec_responses.go index ccccd693330..8c4776d0f47 100644 --- a/internal/lookoutv2/gen/restapi/operations/get_job_spec_responses.go +++ b/internal/lookoutv2/gen/restapi/operations/get_job_spec_responses.go @@ -16,7 +16,8 @@ import ( // GetJobSpecOKCode is the HTTP code returned for type GetJobSpecOK const GetJobSpecOKCode int = 200 -/*GetJobSpecOK Returns raw Job spec +/* +GetJobSpecOK Returns raw Job spec swagger:response getJobSpecOK */ @@ -60,7 +61,8 @@ func (o *GetJobSpecOK) WriteResponse(rw http.ResponseWriter, producer runtime.Pr // GetJobSpecBadRequestCode is the HTTP code returned for type GetJobSpecBadRequest const GetJobSpecBadRequestCode int = 400 -/*GetJobSpecBadRequest Error response +/* +GetJobSpecBadRequest Error response swagger:response getJobSpecBadRequest */ @@ -101,7 +103,8 @@ func (o *GetJobSpecBadRequest) WriteResponse(rw http.ResponseWriter, producer ru } } -/*GetJobSpecDefault Error response +/* +GetJobSpecDefault Error response swagger:response getJobSpecDefault */ diff --git a/internal/lookoutv2/gen/restapi/operations/get_jobs.go b/internal/lookoutv2/gen/restapi/operations/get_jobs.go index 76689ed77d0..b498f593901 100644 --- a/internal/lookoutv2/gen/restapi/operations/get_jobs.go +++ b/internal/lookoutv2/gen/restapi/operations/get_jobs.go @@ -37,10 +37,10 @@ func NewGetJobs(ctx *middleware.Context, handler GetJobsHandler) *GetJobs { return &GetJobs{Context: ctx, Handler: handler} } -/* GetJobs swagger:route POST /api/v1/jobs getJobs +/* + GetJobs swagger:route POST /api/v1/jobs getJobs GetJobs get jobs API - */ type GetJobs 
struct { Context *middleware.Context diff --git a/internal/lookoutv2/gen/restapi/operations/get_jobs_responses.go b/internal/lookoutv2/gen/restapi/operations/get_jobs_responses.go index 2b1802191f6..5af80b4f316 100644 --- a/internal/lookoutv2/gen/restapi/operations/get_jobs_responses.go +++ b/internal/lookoutv2/gen/restapi/operations/get_jobs_responses.go @@ -16,7 +16,8 @@ import ( // GetJobsOKCode is the HTTP code returned for type GetJobsOK const GetJobsOKCode int = 200 -/*GetJobsOK Returns jobs from API +/* +GetJobsOK Returns jobs from API swagger:response getJobsOK */ @@ -60,7 +61,8 @@ func (o *GetJobsOK) WriteResponse(rw http.ResponseWriter, producer runtime.Produ // GetJobsBadRequestCode is the HTTP code returned for type GetJobsBadRequest const GetJobsBadRequestCode int = 400 -/*GetJobsBadRequest Error response +/* +GetJobsBadRequest Error response swagger:response getJobsBadRequest */ @@ -101,7 +103,8 @@ func (o *GetJobsBadRequest) WriteResponse(rw http.ResponseWriter, producer runti } } -/*GetJobsDefault Error response +/* +GetJobsDefault Error response swagger:response getJobsDefault */ diff --git a/internal/lookoutv2/gen/restapi/operations/group_jobs.go b/internal/lookoutv2/gen/restapi/operations/group_jobs.go index 4225045294b..208d7856c68 100644 --- a/internal/lookoutv2/gen/restapi/operations/group_jobs.go +++ b/internal/lookoutv2/gen/restapi/operations/group_jobs.go @@ -37,10 +37,10 @@ func NewGroupJobs(ctx *middleware.Context, handler GroupJobsHandler) *GroupJobs return &GroupJobs{Context: ctx, Handler: handler} } -/* GroupJobs swagger:route POST /api/v1/jobGroups groupJobs +/* + GroupJobs swagger:route POST /api/v1/jobGroups groupJobs GroupJobs group jobs API - */ type GroupJobs struct { Context *middleware.Context diff --git a/internal/lookoutv2/gen/restapi/operations/group_jobs_responses.go b/internal/lookoutv2/gen/restapi/operations/group_jobs_responses.go index ff442c870bc..b34b787fbbf 100644 --- a/internal/lookoutv2/gen/restapi/operations/group_jobs_responses.go +++ b/internal/lookoutv2/gen/restapi/operations/group_jobs_responses.go @@ -16,7 +16,8 @@ import ( // GroupJobsOKCode is the HTTP code returned for type GroupJobsOK const GroupJobsOKCode int = 200 -/*GroupJobsOK Returns job groups from API +/* +GroupJobsOK Returns job groups from API swagger:response groupJobsOK */ @@ -60,7 +61,8 @@ func (o *GroupJobsOK) WriteResponse(rw http.ResponseWriter, producer runtime.Pro // GroupJobsBadRequestCode is the HTTP code returned for type GroupJobsBadRequest const GroupJobsBadRequestCode int = 400 -/*GroupJobsBadRequest Error response +/* +GroupJobsBadRequest Error response swagger:response groupJobsBadRequest */ @@ -101,7 +103,8 @@ func (o *GroupJobsBadRequest) WriteResponse(rw http.ResponseWriter, producer run } } -/*GroupJobsDefault Error response +/* +GroupJobsDefault Error response swagger:response groupJobsDefault */ diff --git a/internal/lookoutv2/pruner/pruner.go b/internal/lookoutv2/pruner/pruner.go index 18ee81c8da1..946917fe30a 100644 --- a/internal/lookoutv2/pruner/pruner.go +++ b/internal/lookoutv2/pruner/pruner.go @@ -1,16 +1,17 @@ package pruner import ( - "context" "time" "github.com/jackc/pgx/v5" "github.com/pkg/errors" log "github.com/sirupsen/logrus" "k8s.io/apimachinery/pkg/util/clock" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) -func PruneDb(ctx context.Context, db *pgx.Conn, keepAfterCompletion time.Duration, batchLimit int, clock clock.Clock) error { +func PruneDb(ctx *armadacontext.Context, db *pgx.Conn, keepAfterCompletion 
time.Duration, batchLimit int, clock clock.Clock) error { now := clock.Now() cutOffTime := now.Add(-keepAfterCompletion) totalJobsToDelete, err := createJobIdsToDeleteTempTable(ctx, db, cutOffTime) @@ -60,10 +61,10 @@ func PruneDb(ctx context.Context, db *pgx.Conn, keepAfterCompletion time.Duratio } // Returns total number of jobs to delete -func createJobIdsToDeleteTempTable(ctx context.Context, db *pgx.Conn, cutOffTime time.Time) (int, error) { +func createJobIdsToDeleteTempTable(ctx *armadacontext.Context, db *pgx.Conn, cutOffTime time.Time) (int, error) { _, err := db.Exec(ctx, ` CREATE TEMP TABLE job_ids_to_delete AS ( - SELECT job_id FROM job + SELECT job_id FROM job WHERE last_transition_time < $1 )`, cutOffTime) if err != nil { @@ -77,7 +78,7 @@ func createJobIdsToDeleteTempTable(ctx context.Context, db *pgx.Conn, cutOffTime return totalJobsToDelete, nil } -func deleteBatch(ctx context.Context, tx pgx.Tx, batchLimit int) (int, error) { +func deleteBatch(ctx *armadacontext.Context, tx pgx.Tx, batchLimit int) (int, error) { _, err := tx.Exec(ctx, "INSERT INTO batch (job_id) SELECT job_id FROM job_ids_to_delete LIMIT $1;", batchLimit) if err != nil { return -1, err diff --git a/internal/lookoutv2/pruner/pruner_test.go b/internal/lookoutv2/pruner/pruner_test.go index a88274c316a..3df18c0cf05 100644 --- a/internal/lookoutv2/pruner/pruner_test.go +++ b/internal/lookoutv2/pruner/pruner_test.go @@ -1,7 +1,6 @@ package pruner import ( - "context" "testing" "time" @@ -10,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "k8s.io/apimachinery/pkg/util/clock" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/database/lookout" "github.com/armadaproject/armada/internal/common/util" @@ -111,7 +111,7 @@ func TestPruneDb(t *testing.T) { converter := instructions.NewInstructionConverter(metrics.Get(), "armadaproject.io/", &compress.NoOpCompressor{}, true) store := lookoutdb.NewLookoutDb(db, metrics.Get(), 3, 10) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Minute) defer cancel() for _, tj := range tc.jobs { runId := uuid.NewString() @@ -156,7 +156,7 @@ func TestPruneDb(t *testing.T) { func selectStringSet(t *testing.T, db *pgxpool.Pool, query string) map[string]bool { t.Helper() - rows, err := db.Query(context.TODO(), query) + rows, err := db.Query(armadacontext.TODO(), query) assert.NoError(t, err) var ss []string for rows.Next() { diff --git a/internal/lookoutv2/repository/getjobrunerror.go b/internal/lookoutv2/repository/getjobrunerror.go index 467da22ec1a..b878c9291fb 100644 --- a/internal/lookoutv2/repository/getjobrunerror.go +++ b/internal/lookoutv2/repository/getjobrunerror.go @@ -1,18 +1,17 @@ package repository import ( - "context" - "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/pkg/errors" log "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" ) type GetJobRunErrorRepository interface { - GetJobRunError(ctx context.Context, runId string) (string, error) + GetJobRunError(ctx *armadacontext.Context, runId string) (string, error) } type SqlGetJobRunErrorRepository struct { @@ -27,7 +26,7 @@ func NewSqlGetJobRunErrorRepository(db *pgxpool.Pool, decompressor compress.Deco } } -func (r *SqlGetJobRunErrorRepository) GetJobRunError(ctx 
context.Context, runId string) (string, error) { +func (r *SqlGetJobRunErrorRepository) GetJobRunError(ctx *armadacontext.Context, runId string) (string, error) { var rawBytes []byte err := r.db.QueryRow(ctx, "SELECT error FROM job_run WHERE run_id = $1 AND error IS NOT NULL", runId).Scan(&rawBytes) if err != nil { diff --git a/internal/lookoutv2/repository/getjobrunerror_test.go b/internal/lookoutv2/repository/getjobrunerror_test.go index 274de5e6d40..4bf2854929d 100644 --- a/internal/lookoutv2/repository/getjobrunerror_test.go +++ b/internal/lookoutv2/repository/getjobrunerror_test.go @@ -1,12 +1,12 @@ package repository import ( - "context" "testing" "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/assert" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/database/lookout" "github.com/armadaproject/armada/internal/lookoutingesterv2/instructions" @@ -34,7 +34,7 @@ func TestGetJobRunError(t *testing.T) { ApiJob() repo := NewSqlGetJobRunErrorRepository(db, &compress.NoOpDecompressor{}) - result, err := repo.GetJobRunError(context.TODO(), runId) + result, err := repo.GetJobRunError(armadacontext.TODO(), runId) assert.NoError(t, err) assert.Equal(t, expected, result) } @@ -46,7 +46,7 @@ func TestGetJobRunError(t *testing.T) { func TestGetJobRunErrorNotFound(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { repo := NewSqlGetJobRunErrorRepository(db, &compress.NoOpDecompressor{}) - _, err := repo.GetJobRunError(context.TODO(), runId) + _, err := repo.GetJobRunError(armadacontext.TODO(), runId) assert.Error(t, err) return nil }) diff --git a/internal/lookoutv2/repository/getjobs.go b/internal/lookoutv2/repository/getjobs.go index eac6cc0aaf5..cce2550d2b1 100644 --- a/internal/lookoutv2/repository/getjobs.go +++ b/internal/lookoutv2/repository/getjobs.go @@ -1,7 +1,6 @@ package repository import ( - "context" "database/sql" "fmt" "sort" @@ -12,13 +11,14 @@ import ( "github.com/pkg/errors" log "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database" "github.com/armadaproject/armada/internal/common/database/lookout" "github.com/armadaproject/armada/internal/lookoutv2/model" ) type GetJobsRepository interface { - GetJobs(ctx context.Context, filters []*model.Filter, order *model.Order, skip int, take int) (*GetJobsResult, error) + GetJobs(ctx *armadacontext.Context, filters []*model.Filter, order *model.Order, skip int, take int) (*GetJobsResult, error) } type SqlGetJobsRepository struct { @@ -77,7 +77,7 @@ func NewSqlGetJobsRepository(db *pgxpool.Pool) *SqlGetJobsRepository { } } -func (r *SqlGetJobsRepository) GetJobs(ctx context.Context, filters []*model.Filter, activeJobSets bool, order *model.Order, skip int, take int) (*GetJobsResult, error) { +func (r *SqlGetJobsRepository) GetJobs(ctx *armadacontext.Context, filters []*model.Filter, activeJobSets bool, order *model.Order, skip int, take int) (*GetJobsResult, error) { var jobRows []*jobRow var runRows []*runRow var annotationRows []*annotationRow @@ -243,7 +243,7 @@ func getJobRunTime(run *model.Run) (time.Time, error) { return time.Time{}, errors.Errorf("error when getting run time for run with id %s", run.RunId) } -func makeJobRows(ctx context.Context, tx pgx.Tx, tmpTableName string) ([]*jobRow, error) { +func makeJobRows(ctx *armadacontext.Context, tx pgx.Tx, tmpTableName string) 
([]*jobRow, error) { query := fmt.Sprintf(` SELECT j.job_id, @@ -302,7 +302,7 @@ func makeJobRows(ctx context.Context, tx pgx.Tx, tmpTableName string) ([]*jobRow return rows, nil } -func makeRunRows(ctx context.Context, tx pgx.Tx, tmpTableName string) ([]*runRow, error) { +func makeRunRows(ctx *armadacontext.Context, tx pgx.Tx, tmpTableName string) ([]*runRow, error) { query := fmt.Sprintf(` SELECT jr.job_id, @@ -347,7 +347,7 @@ func makeRunRows(ctx context.Context, tx pgx.Tx, tmpTableName string) ([]*runRow return rows, nil } -func makeAnnotationRows(ctx context.Context, tx pgx.Tx, tempTableName string) ([]*annotationRow, error) { +func makeAnnotationRows(ctx *armadacontext.Context, tx pgx.Tx, tempTableName string) ([]*annotationRow, error) { query := fmt.Sprintf(` SELECT ual.job_id, diff --git a/internal/lookoutv2/repository/getjobs_test.go b/internal/lookoutv2/repository/getjobs_test.go index 3c28805c198..d5b45a5cae7 100644 --- a/internal/lookoutv2/repository/getjobs_test.go +++ b/internal/lookoutv2/repository/getjobs_test.go @@ -1,7 +1,6 @@ package repository import ( - "context" "fmt" "testing" "time" @@ -11,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "k8s.io/apimachinery/pkg/api/resource" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/database/lookout" "github.com/armadaproject/armada/internal/common/util" @@ -77,7 +77,7 @@ func TestGetJobsSingle(t *testing.T) { Job() repo := NewSqlGetJobsRepository(db) - result, err := repo.GetJobs(context.TODO(), []*model.Filter{}, false, &model.Order{}, 0, 1) + result, err := repo.GetJobs(armadacontext.TODO(), []*model.Filter{}, false, &model.Order{}, 0, 1) assert.NoError(t, err) assert.Len(t, result.Jobs, 1) assert.Equal(t, 1, result.Count) @@ -105,7 +105,7 @@ func TestGetJobsMultipleRuns(t *testing.T) { // Runs should be sorted from oldest -> newest repo := NewSqlGetJobsRepository(db) - result, err := repo.GetJobs(context.TODO(), []*model.Filter{}, false, &model.Order{}, 0, 1) + result, err := repo.GetJobs(armadacontext.TODO(), []*model.Filter{}, false, &model.Order{}, 0, 1) assert.NoError(t, err) assert.Len(t, result.Jobs, 1) assert.Equal(t, 1, result.Count) @@ -119,7 +119,7 @@ func TestOrderByUnsupportedField(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { repo := NewSqlGetJobsRepository(db) _, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -140,7 +140,7 @@ func TestOrderByUnsupportedDirection(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { repo := NewSqlGetJobsRepository(db) _, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -192,7 +192,7 @@ func TestGetJobsOrderByJobId(t *testing.T) { t.Run("ascending order", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -212,7 +212,7 @@ func TestGetJobsOrderByJobId(t *testing.T) { t.Run("descending order", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -259,7 +259,7 @@ func TestGetJobsOrderBySubmissionTime(t *testing.T) { t.Run("ascending order", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -279,7 +279,7 @@ func 
TestGetJobsOrderBySubmissionTime(t *testing.T) { t.Run("descending order", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -330,7 +330,7 @@ func TestGetJobsOrderByLastTransitionTime(t *testing.T) { t.Run("ascending order", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -350,7 +350,7 @@ func TestGetJobsOrderByLastTransitionTime(t *testing.T) { t.Run("descending order", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -377,7 +377,7 @@ func TestFilterByUnsupportedField(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { repo := NewSqlGetJobsRepository(db) _, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "someField", Match: model.MatchExact, @@ -400,7 +400,7 @@ func TestFilterByUnsupportedMatch(t *testing.T) { repo := NewSqlGetJobsRepository(db) _, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "jobId", Match: model.MatchLessThan, @@ -443,7 +443,7 @@ func TestGetJobsById(t *testing.T) { t.Run("exact", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "jobId", Match: model.MatchExact, @@ -499,7 +499,7 @@ func TestGetJobsByQueue(t *testing.T) { t.Run("exact", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "queue", Match: model.MatchExact, @@ -518,7 +518,7 @@ func TestGetJobsByQueue(t *testing.T) { t.Run("startsWith", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "queue", Match: model.MatchStartsWith, @@ -542,7 +542,7 @@ func TestGetJobsByQueue(t *testing.T) { t.Run("contains", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "queue", Match: model.MatchContains, @@ -604,7 +604,7 @@ func TestGetJobsByJobSet(t *testing.T) { t.Run("exact", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "jobSet", Match: model.MatchExact, @@ -623,7 +623,7 @@ func TestGetJobsByJobSet(t *testing.T) { t.Run("startsWith", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "jobSet", Match: model.MatchStartsWith, @@ -647,7 +647,7 @@ func TestGetJobsByJobSet(t *testing.T) { t.Run("contains", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "jobSet", Match: model.MatchContains, @@ -709,7 +709,7 @@ func TestGetJobsByOwner(t *testing.T) { t.Run("exact", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "owner", Match: model.MatchExact, @@ -728,7 +728,7 @@ func TestGetJobsByOwner(t *testing.T) { t.Run("startsWith", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "owner", Match: model.MatchStartsWith, @@ -752,7 +752,7 @@ func TestGetJobsByOwner(t *testing.T) { t.Run("contains", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "owner", Match: model.MatchContains, @@ -817,7 +817,7 @@ func 
TestGetJobsByState(t *testing.T) { t.Run("exact", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "state", Match: model.MatchExact, @@ -836,7 +836,7 @@ func TestGetJobsByState(t *testing.T) { t.Run("anyOf", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "state", Match: model.MatchAnyOf, @@ -923,7 +923,7 @@ func TestGetJobsByAnnotation(t *testing.T) { t.Run("exact", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "annotation-key-1", Match: model.MatchExact, @@ -943,7 +943,7 @@ func TestGetJobsByAnnotation(t *testing.T) { t.Run("exact, multiple annotations", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{ { Field: "annotation-key-1", @@ -971,7 +971,7 @@ func TestGetJobsByAnnotation(t *testing.T) { t.Run("startsWith, multiple annotations", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{ { Field: "annotation-key-1", @@ -1000,7 +1000,7 @@ func TestGetJobsByAnnotation(t *testing.T) { t.Run("contains, multiple annotations", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{ { Field: "annotation-key-1", @@ -1029,7 +1029,7 @@ func TestGetJobsByAnnotation(t *testing.T) { t.Run("exists", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{ { Field: "annotation-key-1", @@ -1093,7 +1093,7 @@ func TestGetJobsByCpu(t *testing.T) { t.Run("exact", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "cpu", Match: model.MatchExact, @@ -1112,7 +1112,7 @@ func TestGetJobsByCpu(t *testing.T) { t.Run("greaterThan", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "cpu", Match: model.MatchGreaterThan, @@ -1135,7 +1135,7 @@ func TestGetJobsByCpu(t *testing.T) { t.Run("lessThan", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "cpu", Match: model.MatchLessThan, @@ -1158,7 +1158,7 @@ func TestGetJobsByCpu(t *testing.T) { t.Run("greaterThanOrEqualTo", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "cpu", Match: model.MatchGreaterThanOrEqualTo, @@ -1182,7 +1182,7 @@ func TestGetJobsByCpu(t *testing.T) { t.Run("lessThanOrEqualTo", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "cpu", Match: model.MatchLessThanOrEqualTo, @@ -1246,7 +1246,7 @@ func TestGetJobsByMemory(t *testing.T) { t.Run("exact", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "memory", Match: model.MatchExact, @@ -1265,7 +1265,7 @@ func TestGetJobsByMemory(t *testing.T) { t.Run("greaterThan", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "memory", Match: model.MatchGreaterThan, @@ -1288,7 +1288,7 @@ func TestGetJobsByMemory(t *testing.T) { t.Run("lessThan", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "memory", Match: model.MatchLessThan, @@ -1311,7 +1311,7 @@ func 
TestGetJobsByMemory(t *testing.T) { t.Run("greaterThanOrEqualTo", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "memory", Match: model.MatchGreaterThanOrEqualTo, @@ -1335,7 +1335,7 @@ func TestGetJobsByMemory(t *testing.T) { t.Run("lessThanOrEqualTo", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "memory", Match: model.MatchLessThanOrEqualTo, @@ -1399,7 +1399,7 @@ func TestGetJobsByEphemeralStorage(t *testing.T) { t.Run("exact", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "ephemeralStorage", Match: model.MatchExact, @@ -1418,7 +1418,7 @@ func TestGetJobsByEphemeralStorage(t *testing.T) { t.Run("greaterThan", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "ephemeralStorage", Match: model.MatchGreaterThan, @@ -1441,7 +1441,7 @@ func TestGetJobsByEphemeralStorage(t *testing.T) { t.Run("lessThan", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "ephemeralStorage", Match: model.MatchLessThan, @@ -1464,7 +1464,7 @@ func TestGetJobsByEphemeralStorage(t *testing.T) { t.Run("greaterThanOrEqualTo", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "ephemeralStorage", Match: model.MatchGreaterThanOrEqualTo, @@ -1488,7 +1488,7 @@ func TestGetJobsByEphemeralStorage(t *testing.T) { t.Run("lessThanOrEqualTo", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "ephemeralStorage", Match: model.MatchLessThanOrEqualTo, @@ -1552,7 +1552,7 @@ func TestGetJobsByGpu(t *testing.T) { t.Run("exact", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "gpu", Match: model.MatchExact, @@ -1571,7 +1571,7 @@ func TestGetJobsByGpu(t *testing.T) { t.Run("greaterThan", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "gpu", Match: model.MatchGreaterThan, @@ -1594,7 +1594,7 @@ func TestGetJobsByGpu(t *testing.T) { t.Run("lessThan", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "gpu", Match: model.MatchLessThan, @@ -1617,7 +1617,7 @@ func TestGetJobsByGpu(t *testing.T) { t.Run("greaterThanOrEqualTo", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "gpu", Match: model.MatchGreaterThanOrEqualTo, @@ -1641,7 +1641,7 @@ func TestGetJobsByGpu(t *testing.T) { t.Run("lessThanOrEqualTo", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "gpu", Match: model.MatchLessThanOrEqualTo, @@ -1705,7 +1705,7 @@ func TestGetJobsByPriority(t *testing.T) { t.Run("exact", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "priority", Match: model.MatchExact, @@ -1724,7 +1724,7 @@ func TestGetJobsByPriority(t *testing.T) { t.Run("greaterThan", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "priority", Match: model.MatchGreaterThan, @@ -1747,7 +1747,7 @@ func TestGetJobsByPriority(t *testing.T) { 
t.Run("lessThan", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "priority", Match: model.MatchLessThan, @@ -1770,7 +1770,7 @@ func TestGetJobsByPriority(t *testing.T) { t.Run("greaterThanOrEqualTo", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "priority", Match: model.MatchGreaterThanOrEqualTo, @@ -1794,7 +1794,7 @@ func TestGetJobsByPriority(t *testing.T) { t.Run("lessThanOrEqualTo", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "priority", Match: model.MatchLessThanOrEqualTo, @@ -1865,7 +1865,7 @@ func TestGetJobsByPriorityClass(t *testing.T) { t.Run("exact", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "priorityClass", Match: model.MatchExact, @@ -1884,7 +1884,7 @@ func TestGetJobsByPriorityClass(t *testing.T) { t.Run("startsWith", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "priorityClass", Match: model.MatchStartsWith, @@ -1908,7 +1908,7 @@ func TestGetJobsByPriorityClass(t *testing.T) { t.Run("contains", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "priorityClass", Match: model.MatchContains, @@ -1957,7 +1957,7 @@ func TestGetJobsSkip(t *testing.T) { skip := 3 take := 5 result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -1977,7 +1977,7 @@ func TestGetJobsSkip(t *testing.T) { skip := 7 take := 5 result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -1997,7 +1997,7 @@ func TestGetJobsSkip(t *testing.T) { skip := 13 take := 5 result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -2057,7 +2057,7 @@ func TestGetJobsComplex(t *testing.T) { skip := 8 take := 5 result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{ { Field: "queue", @@ -2121,7 +2121,7 @@ func TestGetJobsActiveJobSet(t *testing.T) { repo := NewSqlGetJobsRepository(db) result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, true, &model.Order{ diff --git a/internal/lookoutv2/repository/getjobspec.go b/internal/lookoutv2/repository/getjobspec.go index 60c6ac41cd1..55799249f35 100644 --- a/internal/lookoutv2/repository/getjobspec.go +++ b/internal/lookoutv2/repository/getjobspec.go @@ -1,20 +1,19 @@ package repository import ( - "context" - "github.com/gogo/protobuf/proto" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/pkg/errors" log "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/pkg/api" ) type GetJobSpecRepository interface { - GetJobSpec(ctx context.Context, jobId string) (*api.Job, error) + GetJobSpec(ctx *armadacontext.Context, jobId string) (*api.Job, error) } type SqlGetJobSpecRepository struct { @@ -29,7 +28,7 @@ func NewSqlGetJobSpecRepository(db *pgxpool.Pool, decompressor compress.Decompre } } -func (r *SqlGetJobSpecRepository) GetJobSpec(ctx context.Context, jobId string) (*api.Job, error) { +func (r *SqlGetJobSpecRepository) GetJobSpec(ctx *armadacontext.Context, jobId string) 
(*api.Job, error) { var rawBytes []byte err := r.db.QueryRow(ctx, "SELECT job_spec FROM job WHERE job_id = $1", jobId).Scan(&rawBytes) if err != nil { diff --git a/internal/lookoutv2/repository/getjobspec_test.go b/internal/lookoutv2/repository/getjobspec_test.go index d7e00d83671..b13a897e8c4 100644 --- a/internal/lookoutv2/repository/getjobspec_test.go +++ b/internal/lookoutv2/repository/getjobspec_test.go @@ -1,12 +1,12 @@ package repository import ( - "context" "testing" "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/assert" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/database/lookout" "github.com/armadaproject/armada/internal/lookoutingesterv2/instructions" @@ -42,7 +42,7 @@ func TestGetJobSpec(t *testing.T) { ApiJob() repo := NewSqlGetJobSpecRepository(db, &compress.NoOpDecompressor{}) - result, err := repo.GetJobSpec(context.TODO(), jobId) + result, err := repo.GetJobSpec(armadacontext.TODO(), jobId) assert.NoError(t, err) assertApiJobsEquivalent(t, job, result) return nil @@ -53,7 +53,7 @@ func TestGetJobSpec(t *testing.T) { func TestGetJobSpecError(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { repo := NewSqlGetJobSpecRepository(db, &compress.NoOpDecompressor{}) - _, err := repo.GetJobSpec(context.TODO(), jobId) + _, err := repo.GetJobSpec(armadacontext.TODO(), jobId) assert.Error(t, err) return nil }) diff --git a/internal/lookoutv2/repository/groupjobs.go b/internal/lookoutv2/repository/groupjobs.go index dd80976dcd6..20dcb5adb0a 100644 --- a/internal/lookoutv2/repository/groupjobs.go +++ b/internal/lookoutv2/repository/groupjobs.go @@ -1,7 +1,6 @@ package repository import ( - "context" "fmt" "strings" @@ -9,6 +8,7 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/pkg/errors" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/lookoutv2/model" @@ -22,7 +22,7 @@ type GroupByResult struct { type GroupJobsRepository interface { GroupBy( - ctx context.Context, + ctx *armadacontext.Context, filters []*model.Filter, order *model.Order, groupedField string, @@ -47,7 +47,7 @@ func NewSqlGroupJobsRepository(db *pgxpool.Pool) *SqlGroupJobsRepository { } func (r *SqlGroupJobsRepository) GroupBy( - ctx context.Context, + ctx *armadacontext.Context, filters []*model.Filter, activeJobSets bool, order *model.Order, diff --git a/internal/lookoutv2/repository/groupjobs_test.go b/internal/lookoutv2/repository/groupjobs_test.go index 1f255029f8c..b2bd04d5d03 100644 --- a/internal/lookoutv2/repository/groupjobs_test.go +++ b/internal/lookoutv2/repository/groupjobs_test.go @@ -1,7 +1,6 @@ package repository import ( - "context" "fmt" "testing" "time" @@ -10,6 +9,7 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/assert" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/database/lookout" "github.com/armadaproject/armada/internal/common/pointer" @@ -39,7 +39,7 @@ func TestGroupByQueue(t *testing.T) { repo := NewSqlGroupJobsRepository(db) result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -98,7 +98,7 @@ func TestGroupByJobSet(t *testing.T) { repo 
:= NewSqlGroupJobsRepository(db) result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -165,7 +165,7 @@ func TestGroupByState(t *testing.T) { repo := NewSqlGroupJobsRepository(db) result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -330,7 +330,7 @@ func TestGroupByWithFilters(t *testing.T) { repo := NewSqlGroupJobsRepository(db) result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{ { Field: "queue", @@ -452,7 +452,7 @@ func TestGroupJobsWithMaxSubmittedTime(t *testing.T) { repo := NewSqlGroupJobsRepository(db) result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -552,7 +552,7 @@ func TestGroupJobsWithAvgLastTransitionTime(t *testing.T) { repo := NewSqlGroupJobsRepository(db) result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -652,7 +652,7 @@ func TestGroupJobsWithAllStateCounts(t *testing.T) { repo := NewSqlGroupJobsRepository(db) result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -774,7 +774,7 @@ func TestGroupJobsWithFilteredStateCounts(t *testing.T) { repo := NewSqlGroupJobsRepository(db) result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{ { Field: stateField, @@ -898,7 +898,7 @@ func TestGroupJobsComplex(t *testing.T) { repo := NewSqlGroupJobsRepository(db) result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{ { Field: "queue", @@ -997,7 +997,7 @@ func TestGroupByAnnotation(t *testing.T) { repo := NewSqlGroupJobsRepository(db) result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -1112,7 +1112,7 @@ func TestGroupByAnnotationWithFiltersAndAggregates(t *testing.T) { repo := NewSqlGroupJobsRepository(db) result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{ { Field: "queue", @@ -1212,7 +1212,7 @@ func TestGroupJobsSkip(t *testing.T) { skip := 3 take := 5 result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -1242,7 +1242,7 @@ func TestGroupJobsSkip(t *testing.T) { skip := 7 take := 5 result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -1272,7 +1272,7 @@ func TestGroupJobsSkip(t *testing.T) { skip := 13 take := 5 result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -1306,7 +1306,7 @@ func TestGroupJobsValidation(t *testing.T) { t.Run("valid field", func(t *testing.T) { _, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -1325,7 +1325,7 @@ func TestGroupJobsValidation(t *testing.T) { t.Run("invalid field", func(t *testing.T) { _, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -1344,7 +1344,7 @@ func TestGroupJobsValidation(t *testing.T) { t.Run("valid annotation", func(t *testing.T) { _, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -1364,7 +1364,7 @@ func TestGroupJobsValidation(t *testing.T) { t.Run("valid annotation with same name as column", func(t *testing.T) { _, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), 
[]*model.Filter{}, false, &model.Order{ @@ -1427,7 +1427,7 @@ func TestGroupByActiveJobSets(t *testing.T) { repo := NewSqlGroupJobsRepository(db) result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, true, &model.Order{ diff --git a/internal/lookoutv2/repository/util.go b/internal/lookoutv2/repository/util.go index d250f3844dc..62143df1f37 100644 --- a/internal/lookoutv2/repository/util.go +++ b/internal/lookoutv2/repository/util.go @@ -1,7 +1,6 @@ package repository import ( - "context" "fmt" "strings" "time" @@ -13,6 +12,7 @@ import ( "k8s.io/apimachinery/pkg/api/resource" "k8s.io/utils/pointer" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database/lookout" "github.com/armadaproject/armada/internal/common/eventutil" "github.com/armadaproject/armada/internal/common/ingest" @@ -586,8 +586,8 @@ func (js *JobSimulator) Build() *JobSimulator { EventSequences: []*armadaevents.EventSequence{eventSequence}, MessageIds: []pulsar.MessageID{pulsarutils.NewMessageId(1)}, } - instructionSet := js.converter.Convert(context.TODO(), eventSequenceWithIds) - err := js.store.Store(context.TODO(), instructionSet) + instructionSet := js.converter.Convert(armadacontext.TODO(), eventSequenceWithIds) + err := js.store.Store(armadacontext.TODO(), instructionSet) if err != nil { log.WithError(err).Error("Simulator failed to store job in database") } diff --git a/internal/pulsartest/watch.go b/internal/pulsartest/watch.go index cbe6e5834fa..210916cf7e1 100644 --- a/internal/pulsartest/watch.go +++ b/internal/pulsartest/watch.go @@ -1,13 +1,13 @@ package pulsartest import ( - "context" "fmt" "log" "os" "github.com/sanity-io/litter" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/eventutil" "github.com/armadaproject/armada/internal/common/pulsarutils" ) @@ -17,12 +17,12 @@ func (a *App) Watch() error { defer a.Reader.Close() for a.Reader.HasNext() { - msg, err := a.Reader.Next(context.Background()) + msg, err := a.Reader.Next(armadacontext.Background()) if err != nil { log.Fatal(err) } - ctx := context.Background() + ctx := armadacontext.Background() msgId := pulsarutils.New(msg.ID().LedgerID(), msg.ID().EntryID(), msg.ID().PartitionIdx(), msg.ID().BatchIdx()) diff --git a/internal/scheduler/api.go b/internal/scheduler/api.go index 2e6f2779504..a31eba85f5e 100644 --- a/internal/scheduler/api.go +++ b/internal/scheduler/api.go @@ -8,10 +8,10 @@ import ( "github.com/gogo/protobuf/proto" "github.com/gogo/protobuf/types" "github.com/google/uuid" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/pkg/errors" "k8s.io/apimachinery/pkg/util/clock" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/logging" "github.com/armadaproject/armada/internal/common/pulsarutils" @@ -81,9 +81,7 @@ func (srv *ExecutorApi) LeaseJobRuns(stream executorapi.ExecutorApi_LeaseJobRuns return errors.WithStack(err) } - ctx := stream.Context() - log := ctxlogrus.Extract(ctx) - log = log.WithField("executor", req.ExecutorId) + ctx := armadacontext.WithLogField(armadacontext.FromGrpcCtx(stream.Context()), "executor", req.ExecutorId) executor := srv.executorFromLeaseRequest(ctx, req) if err := srv.executorRepository.StoreExecutor(ctx, executor); err != nil { @@ -105,7 +103,7 @@ func (srv *ExecutorApi) LeaseJobRuns(stream 
executorapi.ExecutorApi_LeaseJobRuns if err != nil { return err } - log.Infof( + ctx.Log.Infof( "executor currently has %d job runs; sending %d cancellations and %d new runs", len(requestRuns), len(runsToCancel), len(newRuns), ) @@ -216,19 +214,19 @@ func setPriorityClassName(podSpec *armadaevents.PodSpecWithAvoidList, priorityCl } // ReportEvents publishes all events to Pulsar. The events are compacted for more efficient publishing. -func (srv *ExecutorApi) ReportEvents(ctx context.Context, list *executorapi.EventList) (*types.Empty, error) { +func (srv *ExecutorApi) ReportEvents(grpcCtx context.Context, list *executorapi.EventList) (*types.Empty, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) err := pulsarutils.CompactAndPublishSequences(ctx, list.Events, srv.producer, srv.maxPulsarMessageSizeBytes, schedulers.Pulsar) return &types.Empty{}, err } // executorFromLeaseRequest extracts a schedulerobjects.Executor from the request. -func (srv *ExecutorApi) executorFromLeaseRequest(ctx context.Context, req *executorapi.LeaseRequest) *schedulerobjects.Executor { - log := ctxlogrus.Extract(ctx) +func (srv *ExecutorApi) executorFromLeaseRequest(ctx *armadacontext.Context, req *executorapi.LeaseRequest) *schedulerobjects.Executor { nodes := make([]*schedulerobjects.Node, 0, len(req.Nodes)) now := srv.clock.Now().UTC() for _, nodeInfo := range req.Nodes { if node, err := api.NewNodeFromNodeInfo(nodeInfo, req.ExecutorId, srv.allowedPriorities, now); err != nil { - logging.WithStacktrace(log, err).Warnf( + logging.WithStacktrace(ctx.Log, err).Warnf( "skipping node %s from executor %s", nodeInfo.GetName(), req.GetExecutorId(), ) } else { diff --git a/internal/scheduler/api_test.go b/internal/scheduler/api_test.go index 77a3c52f7a9..f388a5129d4 100644 --- a/internal/scheduler/api_test.go +++ b/internal/scheduler/api_test.go @@ -1,6 +1,7 @@ package scheduler import ( + "context" "testing" "time" @@ -10,10 +11,10 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/net/context" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/util/clock" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/mocks" "github.com/armadaproject/armada/internal/common/pulsarutils" @@ -165,7 +166,7 @@ func TestExecutorApi_LeaseJobRuns(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) ctrl := gomock.NewController(t) mockPulsarProducer := mocks.NewMockProducer(ctrl) mockJobRepository := schedulermocks.NewMockJobRepository(ctrl) @@ -179,11 +180,11 @@ func TestExecutorApi_LeaseJobRuns(t *testing.T) { // set up mocks mockStream.EXPECT().Context().Return(ctx).AnyTimes() mockStream.EXPECT().Recv().Return(tc.request, nil).Times(1) - mockExecutorRepository.EXPECT().StoreExecutor(ctx, gomock.Any()).DoAndReturn(func(ctx context.Context, executor *schedulerobjects.Executor) error { + mockExecutorRepository.EXPECT().StoreExecutor(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx *armadacontext.Context, executor *schedulerobjects.Executor) error { assert.Equal(t, tc.expectedExecutor, executor) return nil }).Times(1) - mockLegacyExecutorRepository.EXPECT().StoreExecutor(ctx, gomock.Any()).DoAndReturn(func(ctx context.Context, executor 
*schedulerobjects.Executor) error { + mockLegacyExecutorRepository.EXPECT().StoreExecutor(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx *armadacontext.Context, executor *schedulerobjects.Executor) error { assert.Equal(t, tc.expectedExecutor, executor) return nil }).Times(1) @@ -304,7 +305,7 @@ func TestExecutorApi_Publish(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) ctrl := gomock.NewController(t) mockPulsarProducer := mocks.NewMockProducer(ctrl) mockJobRepository := schedulermocks.NewMockJobRepository(ctrl) diff --git a/internal/scheduler/database/db.go b/internal/scheduler/database/db.go index 5af3de156f4..8f9fc5e6de2 100644 --- a/internal/scheduler/database/db.go +++ b/internal/scheduler/database/db.go @@ -7,8 +7,8 @@ package database import ( "context" - "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" ) type DBTX interface { diff --git a/internal/scheduler/database/db_pruner.go b/internal/scheduler/database/db_pruner.go index 728c3c9b71b..9ea8075a40d 100644 --- a/internal/scheduler/database/db_pruner.go +++ b/internal/scheduler/database/db_pruner.go @@ -28,7 +28,7 @@ func PruneDb(ctx ctx.Context, db *pgx.Conn, batchLimit int, keepAfterCompletion // Insert the ids of all jobs we want to delete into a tmp table _, err = db.Exec(ctx, `CREATE TEMP TABLE rows_to_delete AS ( - SELECT job_id FROM jobs + SELECT job_id FROM jobs WHERE last_modified < $1 AND (succeeded = TRUE OR failed = TRUE OR cancelled = TRUE))`, cutOffTime) if err != nil { diff --git a/internal/scheduler/database/db_pruner_test.go b/internal/scheduler/database/db_pruner_test.go index bd1165ed2d3..1a30c200463 100644 --- a/internal/scheduler/database/db_pruner_test.go +++ b/internal/scheduler/database/db_pruner_test.go @@ -1,7 +1,6 @@ package database import ( - "context" "fmt" "testing" "time" @@ -12,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "k8s.io/apimachinery/pkg/util/clock" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database" commonutil "github.com/armadaproject/armada/internal/common/util" ) @@ -108,7 +108,7 @@ func TestPruneDb_RemoveJobs(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { err := WithTestDb(func(_ *Queries, db *pgxpool.Pool) error { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 10*time.Second) defer cancel() testClock := clock.NewFakeClock(baseTime) @@ -186,7 +186,7 @@ func TestPruneDb_RemoveMarkers(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { err := WithTestDb(func(_ *Queries, db *pgxpool.Pool) error { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 10*time.Second) defer cancel() testClock := clock.NewFakeClock(baseTime) @@ -220,7 +220,7 @@ func TestPruneDb_RemoveMarkers(t *testing.T) { // Removes the triggers that auto-set serial and last_update_time as // we need to manipulate these as part of the test -func removeTriggers(ctx context.Context, db *pgxpool.Pool) error { +func removeTriggers(ctx *armadacontext.Context, db *pgxpool.Pool) error { triggers := map[string]string{ "jobs": "next_serial_on_insert_jobs", "runs": 
"next_serial_on_insert_runs", diff --git a/internal/scheduler/database/executor_repository.go b/internal/scheduler/database/executor_repository.go index c2da2442e54..ec50db20126 100644 --- a/internal/scheduler/database/executor_repository.go +++ b/internal/scheduler/database/executor_repository.go @@ -1,13 +1,13 @@ package database import ( - "context" "time" "github.com/gogo/protobuf/proto" "github.com/jackc/pgx/v5/pgxpool" "github.com/pkg/errors" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" ) @@ -15,11 +15,11 @@ import ( // ExecutorRepository is an interface to be implemented by structs which provide executor information type ExecutorRepository interface { // GetExecutors returns all known executors, regardless of their last heartbeat time - GetExecutors(ctx context.Context) ([]*schedulerobjects.Executor, error) + GetExecutors(ctx *armadacontext.Context) ([]*schedulerobjects.Executor, error) // GetLastUpdateTimes returns a map of executor name -> last heartbeat time - GetLastUpdateTimes(ctx context.Context) (map[string]time.Time, error) + GetLastUpdateTimes(ctx *armadacontext.Context) (map[string]time.Time, error) // StoreExecutor persists the latest executor state - StoreExecutor(ctx context.Context, executor *schedulerobjects.Executor) error + StoreExecutor(ctx *armadacontext.Context, executor *schedulerobjects.Executor) error } // PostgresExecutorRepository is an implementation of ExecutorRepository that stores its state in postgres @@ -40,7 +40,7 @@ func NewPostgresExecutorRepository(db *pgxpool.Pool) *PostgresExecutorRepository } // GetExecutors returns all known executors, regardless of their last heartbeat time -func (r *PostgresExecutorRepository) GetExecutors(ctx context.Context) ([]*schedulerobjects.Executor, error) { +func (r *PostgresExecutorRepository) GetExecutors(ctx *armadacontext.Context) ([]*schedulerobjects.Executor, error) { queries := New(r.db) requests, err := queries.SelectAllExecutors(ctx) if err != nil { @@ -59,7 +59,7 @@ func (r *PostgresExecutorRepository) GetExecutors(ctx context.Context) ([]*sched } // GetLastUpdateTimes returns a map of executor name -> last heartbeat time -func (r *PostgresExecutorRepository) GetLastUpdateTimes(ctx context.Context) (map[string]time.Time, error) { +func (r *PostgresExecutorRepository) GetLastUpdateTimes(ctx *armadacontext.Context) (map[string]time.Time, error) { queries := New(r.db) rows, err := queries.SelectExecutorUpdateTimes(ctx) if err != nil { @@ -74,7 +74,7 @@ func (r *PostgresExecutorRepository) GetLastUpdateTimes(ctx context.Context) (ma } // StoreExecutor persists the latest executor state -func (r *PostgresExecutorRepository) StoreExecutor(ctx context.Context, executor *schedulerobjects.Executor) error { +func (r *PostgresExecutorRepository) StoreExecutor(ctx *armadacontext.Context, executor *schedulerobjects.Executor) error { queries := New(r.db) bytes, err := proto.Marshal(executor) if err != nil { diff --git a/internal/scheduler/database/executor_repository_test.go b/internal/scheduler/database/executor_repository_test.go index 2d7bd206512..76a0e14c9f9 100644 --- a/internal/scheduler/database/executor_repository_test.go +++ b/internal/scheduler/database/executor_repository_test.go @@ -1,7 +1,6 @@ package database import ( - "context" "testing" "time" @@ -10,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/exp/slices" + 
"github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" ) @@ -53,7 +53,7 @@ func TestExecutorRepository_LoadAndSave(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { err := withExecutorRepository(func(repo *PostgresExecutorRepository) error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() for _, executor := range tc.executors { err := repo.StoreExecutor(ctx, executor) @@ -106,7 +106,7 @@ func TestExecutorRepository_GetLastUpdateTimes(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { err := withExecutorRepository(func(repo *PostgresExecutorRepository) error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() for _, executor := range tc.executors { err := repo.StoreExecutor(ctx, executor) diff --git a/internal/scheduler/database/job_repository.go b/internal/scheduler/database/job_repository.go index c4eaf606099..ebc08d03230 100644 --- a/internal/scheduler/database/job_repository.go +++ b/internal/scheduler/database/job_repository.go @@ -1,7 +1,6 @@ package database import ( - "context" "fmt" "github.com/google/uuid" @@ -9,6 +8,7 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/pkg/errors" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/database" protoutil "github.com/armadaproject/armada/internal/common/proto" @@ -35,24 +35,24 @@ type JobRunLease struct { type JobRepository interface { // FetchJobUpdates returns all jobs and job dbRuns that have been updated after jobSerial and jobRunSerial respectively // These updates are guaranteed to be consistent with each other - FetchJobUpdates(ctx context.Context, jobSerial int64, jobRunSerial int64) ([]Job, []Run, error) + FetchJobUpdates(ctx *armadacontext.Context, jobSerial int64, jobRunSerial int64) ([]Job, []Run, error) // FetchJobRunErrors returns all armadaevents.JobRunErrors for the provided job run ids. The returned map is // keyed by job run id. Any dbRuns which don't have errors wil be absent from the map. - FetchJobRunErrors(ctx context.Context, runIds []uuid.UUID) (map[uuid.UUID]*armadaevents.Error, error) + FetchJobRunErrors(ctx *armadacontext.Context, runIds []uuid.UUID) (map[uuid.UUID]*armadaevents.Error, error) // CountReceivedPartitions returns a count of the number of partition messages present in the database corresponding // to the provided groupId. This is used by the scheduler to determine if the database represents the state of // pulsar after a given point in time. - CountReceivedPartitions(ctx context.Context, groupId uuid.UUID) (uint32, error) + CountReceivedPartitions(ctx *armadacontext.Context, groupId uuid.UUID) (uint32, error) // FindInactiveRuns returns a slice containing all dbRuns that the scheduler does not currently consider active // Runs are inactive if they don't exist or if they have succeeded, failed or been cancelled - FindInactiveRuns(ctx context.Context, runIds []uuid.UUID) ([]uuid.UUID, error) + FindInactiveRuns(ctx *armadacontext.Context, runIds []uuid.UUID) ([]uuid.UUID, error) // FetchJobRunLeases fetches new job runs for a given executor. 
A maximum of maxResults rows will be returned, while run // in excludedRunIds will be excluded - FetchJobRunLeases(ctx context.Context, executor string, maxResults uint, excludedRunIds []uuid.UUID) ([]*JobRunLease, error) + FetchJobRunLeases(ctx *armadacontext.Context, executor string, maxResults uint, excludedRunIds []uuid.UUID) ([]*JobRunLease, error) } // PostgresJobRepository is an implementation of JobRepository that stores its state in postgres @@ -72,7 +72,7 @@ func NewPostgresJobRepository(db *pgxpool.Pool, batchSize int32) *PostgresJobRep // FetchJobRunErrors returns all armadaevents.JobRunErrors for the provided job run ids. The returned map is // keyed by job run id. Any dbRuns which don't have errors wil be absent from the map. -func (r *PostgresJobRepository) FetchJobRunErrors(ctx context.Context, runIds []uuid.UUID) (map[uuid.UUID]*armadaevents.Error, error) { +func (r *PostgresJobRepository) FetchJobRunErrors(ctx *armadacontext.Context, runIds []uuid.UUID) (map[uuid.UUID]*armadaevents.Error, error) { if len(runIds) == 0 { return map[uuid.UUID]*armadaevents.Error{}, nil } @@ -125,7 +125,7 @@ func (r *PostgresJobRepository) FetchJobRunErrors(ctx context.Context, runIds [] // FetchJobUpdates returns all jobs and job dbRuns that have been updated after jobSerial and jobRunSerial respectively // These updates are guaranteed to be consistent with each other -func (r *PostgresJobRepository) FetchJobUpdates(ctx context.Context, jobSerial int64, jobRunSerial int64) ([]Job, []Run, error) { +func (r *PostgresJobRepository) FetchJobUpdates(ctx *armadacontext.Context, jobSerial int64, jobRunSerial int64) ([]Job, []Run, error) { var updatedJobs []Job = nil var updatedRuns []Run = nil @@ -180,7 +180,7 @@ func (r *PostgresJobRepository) FetchJobUpdates(ctx context.Context, jobSerial i // FindInactiveRuns returns a slice containing all dbRuns that the scheduler does not currently consider active // Runs are inactive if they don't exist or if they have succeeded, failed or been cancelled -func (r *PostgresJobRepository) FindInactiveRuns(ctx context.Context, runIds []uuid.UUID) ([]uuid.UUID, error) { +func (r *PostgresJobRepository) FindInactiveRuns(ctx *armadacontext.Context, runIds []uuid.UUID) ([]uuid.UUID, error) { var inactiveRuns []uuid.UUID err := pgx.BeginTxFunc(ctx, r.db, pgx.TxOptions{ IsoLevel: pgx.ReadCommitted, @@ -221,7 +221,7 @@ func (r *PostgresJobRepository) FindInactiveRuns(ctx context.Context, runIds []u // FetchJobRunLeases fetches new job runs for a given executor. A maximum of maxResults rows will be returned, while run // in excludedRunIds will be excluded -func (r *PostgresJobRepository) FetchJobRunLeases(ctx context.Context, executor string, maxResults uint, excludedRunIds []uuid.UUID) ([]*JobRunLease, error) { +func (r *PostgresJobRepository) FetchJobRunLeases(ctx *armadacontext.Context, executor string, maxResults uint, excludedRunIds []uuid.UUID) ([]*JobRunLease, error) { if maxResults == 0 { return []*JobRunLease{}, nil } @@ -272,7 +272,7 @@ func (r *PostgresJobRepository) FetchJobRunLeases(ctx context.Context, executor // CountReceivedPartitions returns a count of the number of partition messages present in the database corresponding // to the provided groupId. This is used by the scheduler to determine if the database represents the state of // pulsar after a given point in time. 
-func (r *PostgresJobRepository) CountReceivedPartitions(ctx context.Context, groupId uuid.UUID) (uint32, error) { +func (r *PostgresJobRepository) CountReceivedPartitions(ctx *armadacontext.Context, groupId uuid.UUID) (uint32, error) { queries := New(r.db) count, err := queries.CountGroup(ctx, groupId) if err != nil { @@ -300,7 +300,7 @@ func fetch[T hasSerial](from int64, batchSize int32, fetchBatch func(int64) ([]T } // Insert all run ids into a tmp table. The name of the table is returned -func insertRunIdsToTmpTable(ctx context.Context, tx pgx.Tx, runIds []uuid.UUID) (string, error) { +func insertRunIdsToTmpTable(ctx *armadacontext.Context, tx pgx.Tx, runIds []uuid.UUID) (string, error) { tmpTable := database.UniqueTableName("job_runs") _, err := tx.Exec(ctx, fmt.Sprintf("CREATE TEMPORARY TABLE %s (run_id uuid) ON COMMIT DROP", tmpTable)) diff --git a/internal/scheduler/database/job_repository_test.go b/internal/scheduler/database/job_repository_test.go index b236618185b..771d887b17d 100644 --- a/internal/scheduler/database/job_repository_test.go +++ b/internal/scheduler/database/job_repository_test.go @@ -1,7 +1,6 @@ package database import ( - "context" "fmt" "testing" "time" @@ -13,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/exp/slices" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/database" protoutil "github.com/armadaproject/armada/internal/common/proto" @@ -84,7 +84,7 @@ func TestFetchJobUpdates(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { err := withJobRepository(func(repo *PostgresJobRepository) error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) // Set up db err := database.UpsertWithTransaction(ctx, repo.db, "jobs", tc.dbJobs) @@ -187,7 +187,7 @@ func TestFetchJobRunErrors(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { err := withJobRepository(func(repo *PostgresJobRepository) error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) // Set up db err := database.UpsertWithTransaction(ctx, repo.db, "job_run_errors", tc.errorsInDb) require.NoError(t, err) @@ -222,7 +222,7 @@ func TestCountReceivedPartitions(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { err := withJobRepository(func(repo *PostgresJobRepository) error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) markers := make([]Marker, tc.numPartitions) groupId := uuid.New() @@ -357,7 +357,7 @@ func TestFindInactiveRuns(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { err := withJobRepository(func(repo *PostgresJobRepository) error { - ctx, cancel := context.WithTimeout(context.Background(), 500*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 500*time.Second) // Set up db err := database.UpsertWithTransaction(ctx, repo.db, "runs", tc.dbRuns) @@ -487,7 +487,7 @@ func TestFetchJobRunLeases(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { err := withJobRepository(func(repo *PostgresJobRepository) error { - ctx, cancel := 
context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) // Set up db err := database.UpsertWithTransaction(ctx, repo.db, "jobs", tc.dbJobs) @@ -553,7 +553,7 @@ func withJobRepository(action func(repository *PostgresJobRepository) error) err }) } -func insertMarkers(ctx context.Context, markers []Marker, db *pgxpool.Pool) error { +func insertMarkers(ctx *armadacontext.Context, markers []Marker, db *pgxpool.Pool) error { for _, marker := range markers { _, err := db.Exec(ctx, "INSERT INTO markers VALUES ($1, $2)", marker.GroupID, marker.PartitionID) if err != nil { diff --git a/internal/scheduler/database/redis_executor_repository.go b/internal/scheduler/database/redis_executor_repository.go index 989710a69da..ef775ff7f75 100644 --- a/internal/scheduler/database/redis_executor_repository.go +++ b/internal/scheduler/database/redis_executor_repository.go @@ -1,7 +1,6 @@ package database import ( - "context" "fmt" "time" @@ -9,6 +8,7 @@ import ( "github.com/gogo/protobuf/proto" "github.com/pkg/errors" + "github.com/armadaproject/armada/internal/common/armadacontext" protoutil "github.com/armadaproject/armada/internal/common/proto" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" ) @@ -29,7 +29,7 @@ func NewRedisExecutorRepository(db redis.UniversalClient, schedulerName string) } } -func (r *RedisExecutorRepository) GetExecutors(_ context.Context) ([]*schedulerobjects.Executor, error) { +func (r *RedisExecutorRepository) GetExecutors(_ *armadacontext.Context) ([]*schedulerobjects.Executor, error) { result, err := r.db.HGetAll(r.executorsKey).Result() if err != nil { return nil, errors.Wrap(err, "Error retrieving executors from redis") @@ -47,12 +47,12 @@ func (r *RedisExecutorRepository) GetExecutors(_ context.Context) ([]*schedulero return executors, nil } -func (r *RedisExecutorRepository) GetLastUpdateTimes(_ context.Context) (map[string]time.Time, error) { +func (r *RedisExecutorRepository) GetLastUpdateTimes(_ *armadacontext.Context) (map[string]time.Time, error) { // We could implement this in a very inefficient way, but I don't believe it's needed so panic for now panic("GetLastUpdateTimes is not implemented") } -func (r *RedisExecutorRepository) StoreExecutor(_ context.Context, executor *schedulerobjects.Executor) error { +func (r *RedisExecutorRepository) StoreExecutor(_ *armadacontext.Context, executor *schedulerobjects.Executor) error { data, err := proto.Marshal(executor) if err != nil { return errors.Wrap(err, "Error marshalling executor proto") diff --git a/internal/scheduler/database/redis_executor_repository_test.go b/internal/scheduler/database/redis_executor_repository_test.go index 6fb48d66c49..bf5b0ea9629 100644 --- a/internal/scheduler/database/redis_executor_repository_test.go +++ b/internal/scheduler/database/redis_executor_repository_test.go @@ -1,7 +1,6 @@ package database import ( - "context" "testing" "time" @@ -10,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/exp/slices" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" ) @@ -53,7 +53,7 @@ func TestRedisExecutorRepository_LoadAndSave(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { withRedisExecutorRepository(func(repo *RedisExecutorRepository) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := 
armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() for _, executor := range tc.executors { err := repo.StoreExecutor(ctx, executor) diff --git a/internal/scheduler/database/util.go b/internal/scheduler/database/util.go index d6539a2a743..618c32c8efb 100644 --- a/internal/scheduler/database/util.go +++ b/internal/scheduler/database/util.go @@ -1,7 +1,6 @@ package database import ( - "context" "embed" _ "embed" "time" @@ -9,13 +8,14 @@ import ( "github.com/jackc/pgx/v5/pgxpool" log "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database" ) //go:embed migrations/*.sql var fs embed.FS -func Migrate(ctx context.Context, db database.Querier) error { +func Migrate(ctx *armadacontext.Context, db database.Querier) error { start := time.Now() migrations, err := database.ReadMigrations(fs, "migrations") if err != nil { diff --git a/internal/scheduler/gang_scheduler.go b/internal/scheduler/gang_scheduler.go index ffca7be9f8e..fb9a3add118 100644 --- a/internal/scheduler/gang_scheduler.go +++ b/internal/scheduler/gang_scheduler.go @@ -1,11 +1,11 @@ package scheduler import ( - "context" "fmt" "github.com/hashicorp/go-memdb" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/util" schedulerconstraints "github.com/armadaproject/armada/internal/scheduler/constraints" schedulercontext "github.com/armadaproject/armada/internal/scheduler/context" @@ -38,7 +38,7 @@ func (sch *GangScheduler) SkipUnsuccessfulSchedulingKeyCheck() { sch.skipUnsuccessfulSchedulingKeyCheck = true } -func (sch *GangScheduler) Schedule(ctx context.Context, gctx *schedulercontext.GangSchedulingContext) (ok bool, unschedulableReason string, err error) { +func (sch *GangScheduler) Schedule(ctx *armadacontext.Context, gctx *schedulercontext.GangSchedulingContext) (ok bool, unschedulableReason string, err error) { // Exit immediately if this is a new gang and we've exceeded any round limits. // // Because this check occurs before adding the gctx to the sctx, @@ -109,7 +109,7 @@ func (sch *GangScheduler) Schedule(ctx context.Context, gctx *schedulercontext.G return sch.trySchedule(ctx, gctx) } -func (sch *GangScheduler) trySchedule(ctx context.Context, gctx *schedulercontext.GangSchedulingContext) (ok bool, unschedulableReason string, err error) { +func (sch *GangScheduler) trySchedule(ctx *armadacontext.Context, gctx *schedulercontext.GangSchedulingContext) (ok bool, unschedulableReason string, err error) { // If no node uniformity constraint, try scheduling across all nodes. 
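The hunks in this change only ever touch call sites, so the shape of the armadacontext package has to be inferred from them: Background(), TODO(), WithTimeout, ErrGroup, WithLogField, and a Log field on the context. A minimal sketch that would satisfy those call sites, assuming the real type simply bundles a context.Context with a logrus entry, might look like this:

package armadacontext

import (
	"context"
	"time"

	"github.com/sirupsen/logrus"
	"golang.org/x/sync/errgroup"
)

// Context is a sketch only: it assumes the real type embeds context.Context
// and exposes a logrus entry as Log, which is all the call sites in this diff rely on.
type Context struct {
	context.Context
	Log *logrus.Entry
}

// Background mirrors context.Background but attaches a default logger.
func Background() *Context {
	return &Context{Context: context.Background(), Log: logrus.NewEntry(logrus.New())}
}

// WithTimeout derives a child context with a deadline, keeping the parent's logger.
func WithTimeout(parent *Context, d time.Duration) (*Context, context.CancelFunc) {
	inner, cancel := context.WithTimeout(parent.Context, d)
	return &Context{Context: inner, Log: parent.Log}, cancel
}

// ErrGroup mirrors errgroup.WithContext but hands back a *Context,
// so the logger travels with the derived context.
func ErrGroup(parent *Context) (*errgroup.Group, *Context) {
	g, inner := errgroup.WithContext(parent.Context)
	return g, &Context{Context: inner, Log: parent.Log}
}

How close this is to the actual internal/common/armadacontext implementation is not visible from this diff; the point is only that each replacement in these files is mechanical once such a wrapper exists.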
if gctx.NodeUniformityLabel == "" { return sch.tryScheduleGang(ctx, gctx) @@ -176,7 +176,7 @@ func (sch *GangScheduler) trySchedule(ctx context.Context, gctx *schedulercontex return sch.tryScheduleGang(ctx, gctx) } -func (sch *GangScheduler) tryScheduleGang(ctx context.Context, gctx *schedulercontext.GangSchedulingContext) (ok bool, unschedulableReason string, err error) { +func (sch *GangScheduler) tryScheduleGang(ctx *armadacontext.Context, gctx *schedulercontext.GangSchedulingContext) (ok bool, unschedulableReason string, err error) { txn := sch.nodeDb.Txn(true) defer txn.Abort() ok, unschedulableReason, err = sch.tryScheduleGangWithTxn(ctx, txn, gctx) @@ -186,7 +186,7 @@ func (sch *GangScheduler) tryScheduleGang(ctx context.Context, gctx *schedulerco return } -func (sch *GangScheduler) tryScheduleGangWithTxn(ctx context.Context, txn *memdb.Txn, gctx *schedulercontext.GangSchedulingContext) (ok bool, unschedulableReason string, err error) { +func (sch *GangScheduler) tryScheduleGangWithTxn(ctx *armadacontext.Context, txn *memdb.Txn, gctx *schedulercontext.GangSchedulingContext) (ok bool, unschedulableReason string, err error) { if ok, err = sch.nodeDb.ScheduleManyWithTxn(txn, gctx.JobSchedulingContexts); err != nil { return } else if !ok { diff --git a/internal/scheduler/gang_scheduler_test.go b/internal/scheduler/gang_scheduler_test.go index e5fbafad703..cc79703d2b2 100644 --- a/internal/scheduler/gang_scheduler_test.go +++ b/internal/scheduler/gang_scheduler_test.go @@ -1,7 +1,6 @@ package scheduler import ( - "context" "testing" "github.com/stretchr/testify/assert" @@ -10,6 +9,7 @@ import ( "k8s.io/apimachinery/pkg/api/resource" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" armadaslices "github.com/armadaproject/armada/internal/common/slices" schedulerconstraints "github.com/armadaproject/armada/internal/scheduler/constraints" schedulercontext "github.com/armadaproject/armada/internal/scheduler/context" @@ -372,7 +372,7 @@ func TestGangScheduler(t *testing.T) { for i, gang := range tc.Gangs { jctxs := schedulercontext.JobSchedulingContextsFromJobs(testfixtures.TestPriorityClasses, gang) gctx := schedulercontext.NewGangSchedulingContext(jctxs) - ok, reason, err := sch.Schedule(context.Background(), gctx) + ok, reason, err := sch.Schedule(armadacontext.Background(), gctx) require.NoError(t, err) if ok { require.Empty(t, reason) diff --git a/internal/scheduler/jobiteration.go b/internal/scheduler/jobiteration.go index 7b232edc141..04dd63a6490 100644 --- a/internal/scheduler/jobiteration.go +++ b/internal/scheduler/jobiteration.go @@ -1,13 +1,12 @@ package scheduler import ( - "context" "sync" "golang.org/x/exp/maps" "golang.org/x/exp/slices" - "golang.org/x/sync/errgroup" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/types" "github.com/armadaproject/armada/internal/scheduler/interfaces" ) @@ -136,7 +135,7 @@ func (repo *InMemoryJobRepository) GetExistingJobsByIds(jobIds []string) ([]inte return rv, nil } -func (repo *InMemoryJobRepository) GetJobIterator(ctx context.Context, queue string) (JobIterator, error) { +func (repo *InMemoryJobRepository) GetJobIterator(ctx *armadacontext.Context, queue string) (JobIterator, error) { repo.mu.Lock() defer repo.mu.Unlock() return NewInMemoryJobIterator(slices.Clone(repo.jobsByQueue[queue])), nil @@ -145,14 +144,14 @@ func (repo *InMemoryJobRepository) GetJobIterator(ctx context.Context, queue str 
// QueuedJobsIterator is an iterator over all jobs in a queue. // It lazily loads jobs in batches from Redis asynch. type QueuedJobsIterator struct { - ctx context.Context + ctx *armadacontext.Context err error c chan interfaces.LegacySchedulerJob } -func NewQueuedJobsIterator(ctx context.Context, queue string, repo JobRepository) (*QueuedJobsIterator, error) { +func NewQueuedJobsIterator(ctx *armadacontext.Context, queue string, repo JobRepository) (*QueuedJobsIterator, error) { batchSize := 16 - g, ctx := errgroup.WithContext(ctx) + g, ctx := armadacontext.ErrGroup(ctx) it := &QueuedJobsIterator{ ctx: ctx, c: make(chan interfaces.LegacySchedulerJob, 2*batchSize), // 2x batchSize to load one batch async. @@ -190,7 +189,7 @@ func (it *QueuedJobsIterator) Next() (interfaces.LegacySchedulerJob, error) { // queuedJobsIteratorLoader loads jobs from Redis lazily. // Used with QueuedJobsIterator. -func queuedJobsIteratorLoader(ctx context.Context, jobIds []string, ch chan interfaces.LegacySchedulerJob, batchSize int, repo JobRepository) error { +func queuedJobsIteratorLoader(ctx *armadacontext.Context, jobIds []string, ch chan interfaces.LegacySchedulerJob, batchSize int, repo JobRepository) error { defer close(ch) batch := make([]string, batchSize) for i, jobId := range jobIds { diff --git a/internal/scheduler/jobiteration_test.go b/internal/scheduler/jobiteration_test.go index 42133f0ba05..a5990fa3fc4 100644 --- a/internal/scheduler/jobiteration_test.go +++ b/internal/scheduler/jobiteration_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" v1 "k8s.io/api/core/v1" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/scheduler/interfaces" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" @@ -87,7 +88,7 @@ func TestMultiJobsIterator_TwoQueues(t *testing.T) { expected = append(expected, job.Id) } - ctx := context.Background() + ctx := armadacontext.Background() its := make([]JobIterator, 3) for i, queue := range []string{"A", "B", "C"} { it, err := NewQueuedJobsIterator(ctx, queue, repo) @@ -121,7 +122,7 @@ func TestQueuedJobsIterator_OneQueue(t *testing.T) { expected = append(expected, job.Id) } - ctx := context.Background() + ctx := armadacontext.Background() it, err := NewQueuedJobsIterator(ctx, "A", repo) if !assert.NoError(t, err) { return @@ -146,7 +147,7 @@ func TestQueuedJobsIterator_ExceedsBufferSize(t *testing.T) { expected = append(expected, job.Id) } - ctx := context.Background() + ctx := armadacontext.Background() it, err := NewQueuedJobsIterator(ctx, "A", repo) if !assert.NoError(t, err) { return @@ -171,7 +172,7 @@ func TestQueuedJobsIterator_ManyJobs(t *testing.T) { expected = append(expected, job.Id) } - ctx := context.Background() + ctx := armadacontext.Background() it, err := NewQueuedJobsIterator(ctx, "A", repo) if !assert.NoError(t, err) { return @@ -200,7 +201,7 @@ func TestCreateQueuedJobsIterator_TwoQueues(t *testing.T) { repo.Enqueue(job) } - ctx := context.Background() + ctx := armadacontext.Background() it, err := NewQueuedJobsIterator(ctx, "A", repo) if !assert.NoError(t, err) { return @@ -223,7 +224,7 @@ func TestCreateQueuedJobsIterator_RespectsTimeout(t *testing.T) { repo.Enqueue(job) } - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), time.Millisecond) time.Sleep(20 * time.Millisecond) defer cancel() it, err := 
NewQueuedJobsIterator(ctx, "A", repo) @@ -248,7 +249,7 @@ func TestCreateQueuedJobsIterator_NilOnEmpty(t *testing.T) { repo.Enqueue(job) } - ctx := context.Background() + ctx := armadacontext.Background() it, err := NewQueuedJobsIterator(ctx, "A", repo) if !assert.NoError(t, err) { return @@ -291,7 +292,7 @@ func (repo *mockJobRepository) Enqueue(job *api.Job) { repo.jobsById[job.Id] = job } -func (repo *mockJobRepository) GetJobIterator(ctx context.Context, queue string) (JobIterator, error) { +func (repo *mockJobRepository) GetJobIterator(ctx *armadacontext.Context, queue string) (JobIterator, error) { return NewQueuedJobsIterator(ctx, queue, repo) } diff --git a/internal/scheduler/leader.go b/internal/scheduler/leader.go index a0c8b8a85f6..0482184a7a8 100644 --- a/internal/scheduler/leader.go +++ b/internal/scheduler/leader.go @@ -6,12 +6,12 @@ import ( "sync/atomic" "github.com/google/uuid" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" coordinationv1client "k8s.io/client-go/kubernetes/typed/coordination/v1" "k8s.io/client-go/tools/leaderelection" "k8s.io/client-go/tools/leaderelection/resourcelock" + "github.com/armadaproject/armada/internal/common/armadacontext" schedulerconfig "github.com/armadaproject/armada/internal/scheduler/configuration" ) @@ -23,7 +23,7 @@ type LeaderController interface { // Returns true if the token is a leader and false otherwise ValidateToken(tok LeaderToken) bool // Run starts the controller. This is a blocking call which will return when the provided context is cancelled - Run(ctx context.Context) error + Run(ctx *armadacontext.Context) error // GetLeaderReport returns a report about the current leader GetLeaderReport() LeaderReport } @@ -85,14 +85,14 @@ func (lc *StandaloneLeaderController) ValidateToken(tok LeaderToken) bool { return false } -func (lc *StandaloneLeaderController) Run(ctx context.Context) error { +func (lc *StandaloneLeaderController) Run(ctx *armadacontext.Context) error { return nil } // LeaseListener allows clients to listen for lease events. type LeaseListener interface { // Called when the client has started leading. - onStartedLeading(context.Context) + onStartedLeading(*armadacontext.Context) // Called when the client has stopped leading, onStoppedLeading() } @@ -138,16 +138,14 @@ func (lc *KubernetesLeaderController) ValidateToken(tok LeaderToken) bool { // Run starts the controller. // This is a blocking call that returns when the provided context is cancelled. 
-func (lc *KubernetesLeaderController) Run(ctx context.Context) error { - log := ctxlogrus.Extract(ctx) - log = log.WithField("service", "KubernetesLeaderController") +func (lc *KubernetesLeaderController) Run(ctx *armadacontext.Context) error { for { select { case <-ctx.Done(): return ctx.Err() default: lock := lc.getNewLock() - log.Infof("attempting to become leader") + ctx.Log.Infof("attempting to become leader") leaderelection.RunOrDie(ctx, leaderelection.LeaderElectionConfig{ Lock: lock, ReleaseOnCancel: true, @@ -156,14 +154,14 @@ func (lc *KubernetesLeaderController) Run(ctx context.Context) error { RetryPeriod: lc.config.RetryPeriod, Callbacks: leaderelection.LeaderCallbacks{ OnStartedLeading: func(c context.Context) { - log.Infof("I am now leader") + ctx.Log.Infof("I am now leader") lc.token.Store(NewLeaderToken()) for _, listener := range lc.listeners { listener.onStartedLeading(ctx) } }, OnStoppedLeading: func() { - log.Infof("I am no longer leader") + ctx.Log.Infof("I am no longer leader") lc.token.Store(InvalidLeaderToken()) for _, listener := range lc.listeners { listener.onStoppedLeading() @@ -176,7 +174,7 @@ func (lc *KubernetesLeaderController) Run(ctx context.Context) error { }, }, }) - log.Infof("leader election round finished") + ctx.Log.Infof("leader election round finished") } } } diff --git a/internal/scheduler/leader_client_test.go b/internal/scheduler/leader_client_test.go index e8909356402..31ba46a8913 100644 --- a/internal/scheduler/leader_client_test.go +++ b/internal/scheduler/leader_client_test.go @@ -1,11 +1,11 @@ package scheduler import ( - "context" "testing" "github.com/stretchr/testify/assert" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/scheduler/configuration" "github.com/armadaproject/armada/pkg/client" ) @@ -91,7 +91,7 @@ func (f *FakeLeaderController) ValidateToken(tok LeaderToken) bool { return f.IsCurrentlyLeader } -func (f *FakeLeaderController) Run(ctx context.Context) error { +func (f *FakeLeaderController) Run(_ *armadacontext.Context) error { return nil } diff --git a/internal/scheduler/leader_metrics.go b/internal/scheduler/leader_metrics.go index cc02157504e..d5d4e62f535 100644 --- a/internal/scheduler/leader_metrics.go +++ b/internal/scheduler/leader_metrics.go @@ -1,11 +1,11 @@ package scheduler import ( - "context" "sync" "github.com/prometheus/client_golang/prometheus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/metrics" ) @@ -29,7 +29,7 @@ func NewLeaderStatusMetricsCollector(currentInstanceName string) *LeaderStatusMe } } -func (l *LeaderStatusMetricsCollector) onStartedLeading(context.Context) { +func (l *LeaderStatusMetricsCollector) onStartedLeading(*armadacontext.Context) { l.lock.Lock() defer l.lock.Unlock() diff --git a/internal/scheduler/leader_metrics_test.go b/internal/scheduler/leader_metrics_test.go index fec5d4e5d08..8132179afbd 100644 --- a/internal/scheduler/leader_metrics_test.go +++ b/internal/scheduler/leader_metrics_test.go @@ -1,11 +1,12 @@ package scheduler import ( - "context" "testing" "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/assert" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) const testInstanceName = "instance-1" @@ -31,7 +32,7 @@ func TestLeaderStatusMetrics_HandlesLeaderChanges(t *testing.T) { assert.Equal(t, actual[0], isNotLeaderMetric) // start leading - collector.onStartedLeading(context.Background()) + 
collector.onStartedLeading(armadacontext.Background()) actual = getCurrentMetrics(collector) assert.Len(t, actual, 1) assert.Equal(t, actual[0], isLeaderMetric) diff --git a/internal/scheduler/leader_proxying_reports_server_test.go b/internal/scheduler/leader_proxying_reports_server_test.go index 5fc1874d210..2b83a02da28 100644 --- a/internal/scheduler/leader_proxying_reports_server_test.go +++ b/internal/scheduler/leader_proxying_reports_server_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "google.golang.org/grpc" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" ) @@ -48,7 +49,7 @@ func TestLeaderProxyingSchedulingReportsServer_GetJobReports(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() sut, clientProvider, jobReportsServer, jobReportsClient := setupLeaderProxyingSchedulerReportsServerTest(t) @@ -113,7 +114,7 @@ func TestLeaderProxyingSchedulingReportsServer_GetSchedulingReport(t *testing.T) } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() sut, clientProvider, jobReportsServer, jobReportsClient := setupLeaderProxyingSchedulerReportsServerTest(t) @@ -178,7 +179,7 @@ func TestLeaderProxyingSchedulingReportsServer_GetQueueReport(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() sut, clientProvider, jobReportsServer, jobReportsClient := setupLeaderProxyingSchedulerReportsServerTest(t) diff --git a/internal/scheduler/leader_test.go b/internal/scheduler/leader_test.go index 1790c9518b5..17fb468b0cf 100644 --- a/internal/scheduler/leader_test.go +++ b/internal/scheduler/leader_test.go @@ -12,6 +12,7 @@ import ( v1 "k8s.io/api/coordination/v1" "k8s.io/utils/pointer" + "github.com/armadaproject/armada/internal/common/armadacontext" schedulerconfig "github.com/armadaproject/armada/internal/scheduler/configuration" schedulermocks "github.com/armadaproject/armada/internal/scheduler/mocks" ) @@ -108,7 +109,7 @@ func TestK8sLeaderController_BecomingLeader(t *testing.T) { controller := NewKubernetesLeaderController(testLeaderConfig(), client) testListener := NewTestLeaseListener(controller) controller.RegisterListener(testListener) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 10*time.Second) go func() { err := controller.Run(ctx) assert.ErrorIs(t, err, context.Canceled) @@ -184,7 +185,7 @@ func (t *TestLeaseListener) GetMessages() []LeaderToken { return append([]LeaderToken(nil), t.tokens...) 
} -func (t *TestLeaseListener) onStartedLeading(_ context.Context) { +func (t *TestLeaseListener) onStartedLeading(_ *armadacontext.Context) { t.handleNewToken() } diff --git a/internal/scheduler/metrics.go b/internal/scheduler/metrics.go index 15da0d6c478..a7fb2f08c78 100644 --- a/internal/scheduler/metrics.go +++ b/internal/scheduler/metrics.go @@ -1,7 +1,6 @@ package scheduler import ( - "context" "strings" "sync/atomic" "time" @@ -11,6 +10,7 @@ import ( log "github.com/sirupsen/logrus" "k8s.io/apimachinery/pkg/util/clock" + "github.com/armadaproject/armada/internal/common/armadacontext" commonmetrics "github.com/armadaproject/armada/internal/common/metrics" "github.com/armadaproject/armada/internal/common/resource" "github.com/armadaproject/armada/internal/scheduler/database" @@ -76,7 +76,7 @@ func NewMetricsCollector( } // Run enters s a loop which updates the metrics every refreshPeriod until the supplied context is cancelled -func (c *MetricsCollector) Run(ctx context.Context) error { +func (c *MetricsCollector) Run(ctx *armadacontext.Context) error { ticker := c.clock.NewTicker(c.refreshPeriod) log.Infof("Will update metrics every %s", c.refreshPeriod) for { @@ -108,7 +108,7 @@ func (c *MetricsCollector) Collect(metrics chan<- prometheus.Metric) { } } -func (c *MetricsCollector) refresh(ctx context.Context) error { +func (c *MetricsCollector) refresh(ctx *armadacontext.Context) error { log.Debugf("Refreshing prometheus metrics") start := time.Now() queueMetrics, err := c.updateQueueMetrics(ctx) @@ -125,7 +125,7 @@ func (c *MetricsCollector) refresh(ctx context.Context) error { return nil } -func (c *MetricsCollector) updateQueueMetrics(ctx context.Context) ([]prometheus.Metric, error) { +func (c *MetricsCollector) updateQueueMetrics(ctx *armadacontext.Context) ([]prometheus.Metric, error) { queues, err := c.queueRepository.GetAllQueues() if err != nil { return nil, err @@ -212,7 +212,7 @@ type clusterMetricKey struct { nodeType string } -func (c *MetricsCollector) updateClusterMetrics(ctx context.Context) ([]prometheus.Metric, error) { +func (c *MetricsCollector) updateClusterMetrics(ctx *armadacontext.Context) ([]prometheus.Metric, error) { executors, err := c.executorRepository.GetExecutors(ctx) if err != nil { return nil, err diff --git a/internal/scheduler/metrics_test.go b/internal/scheduler/metrics_test.go index 52c89eb6641..0bbcd9090c7 100644 --- a/internal/scheduler/metrics_test.go +++ b/internal/scheduler/metrics_test.go @@ -1,7 +1,6 @@ package scheduler import ( - "context" "testing" "time" @@ -12,6 +11,7 @@ import ( "k8s.io/apimachinery/pkg/api/resource" "k8s.io/apimachinery/pkg/util/clock" + "github.com/armadaproject/armada/internal/common/armadacontext" commonmetrics "github.com/armadaproject/armada/internal/common/metrics" "github.com/armadaproject/armada/internal/scheduler/database" "github.com/armadaproject/armada/internal/scheduler/jobdb" @@ -86,7 +86,7 @@ func TestMetricsCollector_TestCollect_QueueMetrics(t *testing.T) { t.Run(name, func(t *testing.T) { ctrl := gomock.NewController(t) testClock := clock.NewFakeClock(testfixtures.BaseTime) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() // set up job db with initial jobs @@ -236,7 +236,7 @@ func TestMetricsCollector_TestCollect_ClusterMetrics(t *testing.T) { t.Run(name, func(t *testing.T) { ctrl := gomock.NewController(t) testClock := clock.NewFakeClock(testfixtures.BaseTime) - ctx, cancel 
:= context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() // set up job db with initial jobs @@ -303,7 +303,7 @@ type MockPoolAssigner struct { poolsById map[string]string } -func (m MockPoolAssigner) Refresh(_ context.Context) error { +func (m MockPoolAssigner) Refresh(_ *armadacontext.Context) error { return nil } diff --git a/internal/scheduler/mocks/mock_repositories.go b/internal/scheduler/mocks/mock_repositories.go index 9a8f6efee1a..c2924402b9b 100644 --- a/internal/scheduler/mocks/mock_repositories.go +++ b/internal/scheduler/mocks/mock_repositories.go @@ -5,10 +5,10 @@ package schedulermocks import ( - context "context" reflect "reflect" time "time" + armadacontext "github.com/armadaproject/armada/internal/common/armadacontext" database "github.com/armadaproject/armada/internal/scheduler/database" schedulerobjects "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" armadaevents "github.com/armadaproject/armada/pkg/armadaevents" @@ -40,7 +40,7 @@ func (m *MockExecutorRepository) EXPECT() *MockExecutorRepositoryMockRecorder { } // GetExecutors mocks base method. -func (m *MockExecutorRepository) GetExecutors(arg0 context.Context) ([]*schedulerobjects.Executor, error) { +func (m *MockExecutorRepository) GetExecutors(arg0 *armadacontext.Context) ([]*schedulerobjects.Executor, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetExecutors", arg0) ret0, _ := ret[0].([]*schedulerobjects.Executor) @@ -55,7 +55,7 @@ func (mr *MockExecutorRepositoryMockRecorder) GetExecutors(arg0 interface{}) *go } // GetLastUpdateTimes mocks base method. -func (m *MockExecutorRepository) GetLastUpdateTimes(arg0 context.Context) (map[string]time.Time, error) { +func (m *MockExecutorRepository) GetLastUpdateTimes(arg0 *armadacontext.Context) (map[string]time.Time, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetLastUpdateTimes", arg0) ret0, _ := ret[0].(map[string]time.Time) @@ -70,7 +70,7 @@ func (mr *MockExecutorRepositoryMockRecorder) GetLastUpdateTimes(arg0 interface{ } // StoreExecutor mocks base method. -func (m *MockExecutorRepository) StoreExecutor(arg0 context.Context, arg1 *schedulerobjects.Executor) error { +func (m *MockExecutorRepository) StoreExecutor(arg0 *armadacontext.Context, arg1 *schedulerobjects.Executor) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "StoreExecutor", arg0, arg1) ret0, _ := ret[0].(error) @@ -145,7 +145,7 @@ func (m *MockJobRepository) EXPECT() *MockJobRepositoryMockRecorder { } // CountReceivedPartitions mocks base method. -func (m *MockJobRepository) CountReceivedPartitions(arg0 context.Context, arg1 uuid.UUID) (uint32, error) { +func (m *MockJobRepository) CountReceivedPartitions(arg0 *armadacontext.Context, arg1 uuid.UUID) (uint32, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CountReceivedPartitions", arg0, arg1) ret0, _ := ret[0].(uint32) @@ -160,7 +160,7 @@ func (mr *MockJobRepositoryMockRecorder) CountReceivedPartitions(arg0, arg1 inte } // FetchJobRunErrors mocks base method. 
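These mocks are regenerated by mockgen rather than edited by hand, so test code only changes where a concrete context value is passed in: gomock matchers stay as gomock.Any(). A sketch of an expectation against the regenerated mock, assuming the usual generated NewMockExecutorRepository constructor and a testing.T in scope:

ctrl := gomock.NewController(t)
repo := schedulermocks.NewMockExecutorRepository(ctrl)

// The matcher is unchanged; only the value handed to the mocked method differs.
repo.EXPECT().GetExecutors(gomock.Any()).Return([]*schedulerobjects.Executor{}, nil)

executors, err := repo.GetExecutors(armadacontext.Background())
require.NoError(t, err)
require.Empty(t, executors)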
-func (m *MockJobRepository) FetchJobRunErrors(arg0 context.Context, arg1 []uuid.UUID) (map[uuid.UUID]*armadaevents.Error, error) { +func (m *MockJobRepository) FetchJobRunErrors(arg0 *armadacontext.Context, arg1 []uuid.UUID) (map[uuid.UUID]*armadaevents.Error, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "FetchJobRunErrors", arg0, arg1) ret0, _ := ret[0].(map[uuid.UUID]*armadaevents.Error) @@ -175,7 +175,7 @@ func (mr *MockJobRepositoryMockRecorder) FetchJobRunErrors(arg0, arg1 interface{ } // FetchJobRunLeases mocks base method. -func (m *MockJobRepository) FetchJobRunLeases(arg0 context.Context, arg1 string, arg2 uint, arg3 []uuid.UUID) ([]*database.JobRunLease, error) { +func (m *MockJobRepository) FetchJobRunLeases(arg0 *armadacontext.Context, arg1 string, arg2 uint, arg3 []uuid.UUID) ([]*database.JobRunLease, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "FetchJobRunLeases", arg0, arg1, arg2, arg3) ret0, _ := ret[0].([]*database.JobRunLease) @@ -190,7 +190,7 @@ func (mr *MockJobRepositoryMockRecorder) FetchJobRunLeases(arg0, arg1, arg2, arg } // FetchJobUpdates mocks base method. -func (m *MockJobRepository) FetchJobUpdates(arg0 context.Context, arg1, arg2 int64) ([]database.Job, []database.Run, error) { +func (m *MockJobRepository) FetchJobUpdates(arg0 *armadacontext.Context, arg1, arg2 int64) ([]database.Job, []database.Run, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "FetchJobUpdates", arg0, arg1, arg2) ret0, _ := ret[0].([]database.Job) @@ -206,7 +206,7 @@ func (mr *MockJobRepositoryMockRecorder) FetchJobUpdates(arg0, arg1, arg2 interf } // FindInactiveRuns mocks base method. -func (m *MockJobRepository) FindInactiveRuns(arg0 context.Context, arg1 []uuid.UUID) ([]uuid.UUID, error) { +func (m *MockJobRepository) FindInactiveRuns(arg0 *armadacontext.Context, arg1 []uuid.UUID) ([]uuid.UUID, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "FindInactiveRuns", arg0, arg1) ret0, _ := ret[0].([]uuid.UUID) diff --git a/internal/scheduler/pool_assigner.go b/internal/scheduler/pool_assigner.go index 94aa07e4908..9ff1f9b140c 100644 --- a/internal/scheduler/pool_assigner.go +++ b/internal/scheduler/pool_assigner.go @@ -1,7 +1,6 @@ package scheduler import ( - "context" "time" "github.com/gogo/protobuf/proto" @@ -10,6 +9,7 @@ import ( "k8s.io/apimachinery/pkg/util/clock" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/types" "github.com/armadaproject/armada/internal/scheduler/constraints" schedulercontext "github.com/armadaproject/armada/internal/scheduler/context" @@ -22,7 +22,7 @@ import ( // PoolAssigner allows jobs to be assigned to a pool // Note that this is intended only for use with metrics calculation type PoolAssigner interface { - Refresh(ctx context.Context) error + Refresh(ctx *armadacontext.Context) error AssignPool(j *jobdb.Job) (string, error) } @@ -71,7 +71,7 @@ func NewPoolAssigner(executorTimeout time.Duration, } // Refresh updates executor state -func (p *DefaultPoolAssigner) Refresh(ctx context.Context) error { +func (p *DefaultPoolAssigner) Refresh(ctx *armadacontext.Context) error { executors, err := p.executorRepository.GetExecutors(ctx) executorsByPool := map[string][]*executor{} poolByExecutorId := map[string]string{} diff --git a/internal/scheduler/pool_assigner_test.go b/internal/scheduler/pool_assigner_test.go index f2508295e65..7734b6195be 100644 --- a/internal/scheduler/pool_assigner_test.go +++ 
b/internal/scheduler/pool_assigner_test.go @@ -1,17 +1,16 @@ package scheduler import ( - "context" "testing" "time" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" - "k8s.io/apimachinery/pkg/util/clock" - "github.com/stretchr/testify/require" + "k8s.io/apimachinery/pkg/util/clock" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/scheduler/jobdb" schedulermocks "github.com/armadaproject/armada/internal/scheduler/mocks" @@ -48,7 +47,7 @@ func TestPoolAssigner_AssignPool(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() ctrl := gomock.NewController(t) diff --git a/internal/scheduler/preempting_queue_scheduler.go b/internal/scheduler/preempting_queue_scheduler.go index ebf2c35b390..fd0c0d9e079 100644 --- a/internal/scheduler/preempting_queue_scheduler.go +++ b/internal/scheduler/preempting_queue_scheduler.go @@ -1,12 +1,10 @@ package scheduler import ( - "context" "fmt" "math/rand" "time" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/hashicorp/go-memdb" "github.com/pkg/errors" "golang.org/x/exp/maps" @@ -14,6 +12,7 @@ import ( "k8s.io/apimachinery/pkg/api/resource" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" armadamaps "github.com/armadaproject/armada/internal/common/maps" armadaslices "github.com/armadaproject/armada/internal/common/slices" "github.com/armadaproject/armada/internal/common/types" @@ -108,9 +107,7 @@ func (sch *PreemptingQueueScheduler) EnableNewPreemptionStrategy() { // Schedule // - preempts jobs belonging to queues with total allocation above their fair share and // - schedules new jobs belonging to queues with total allocation less than their fair share. -func (sch *PreemptingQueueScheduler) Schedule(ctx context.Context) (*SchedulerResult, error) { - log := ctxlogrus.Extract(ctx) - log = log.WithField("service", "PreemptingQueueScheduler") +func (sch *PreemptingQueueScheduler) Schedule(ctx *armadacontext.Context) (*SchedulerResult, error) { defer func() { sch.schedulingContext.Finished = time.Now() }() @@ -125,23 +122,18 @@ func (sch *PreemptingQueueScheduler) Schedule(ctx context.Context) (*SchedulerRe // Evict preemptible jobs. 
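The PoolAssigner interface in the previous file follows the same mechanical pattern. A hedged sketch of a caller refreshing it under a deadline (assigner and job are assumed to come from the usual fixtures):

ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second)
defer cancel()

if err := assigner.Refresh(ctx); err != nil {
	return err
}
pool, err := assigner.AssignPool(job)
if err != nil {
	return err
}
ctx.Log.Infof("assigned to pool %s", pool)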
totalCost := sch.schedulingContext.TotalCost() evictorResult, inMemoryJobRepo, err := sch.evict( - ctxlogrus.ToContext( - ctx, - log.WithField("stage", "evict for resource balancing"), - ), + armadacontext.WithLogField(ctx, "stage", "evict for resource balancing"), NewNodeEvictor( sch.jobRepo, sch.schedulingContext.PriorityClasses, sch.nodeEvictionProbability, - func(ctx context.Context, job interfaces.LegacySchedulerJob) bool { + func(ctx *armadacontext.Context, job interfaces.LegacySchedulerJob) bool { if job.GetAnnotations() == nil { - log := ctxlogrus.Extract(ctx) - log.Errorf("can't evict job %s: annotations not initialised", job.GetId()) + ctx.Log.Errorf("can't evict job %s: annotations not initialised", job.GetId()) return false } if job.GetNodeSelector() == nil { - log := ctxlogrus.Extract(ctx) - log.Errorf("can't evict job %s: nodeSelector not initialised", job.GetId()) + ctx.Log.Errorf("can't evict job %s: nodeSelector not initialised", job.GetId()) return false } if qctx, ok := sch.schedulingContext.QueueSchedulingContexts[job.GetQueue()]; ok { @@ -168,10 +160,7 @@ func (sch *PreemptingQueueScheduler) Schedule(ctx context.Context) (*SchedulerRe // Re-schedule evicted jobs/schedule new jobs. schedulerResult, err := sch.schedule( - ctxlogrus.ToContext( - ctx, - log.WithField("stage", "re-schedule after balancing eviction"), - ), + armadacontext.WithLogField(ctx, "stage", "re-schedule after balancing eviction"), inMemoryJobRepo, sch.jobRepo, ) @@ -189,10 +178,7 @@ func (sch *PreemptingQueueScheduler) Schedule(ctx context.Context) (*SchedulerRe // Evict jobs on oversubscribed nodes. evictorResult, inMemoryJobRepo, err = sch.evict( - ctxlogrus.ToContext( - ctx, - log.WithField("stage", "evict oversubscribed"), - ), + armadacontext.WithLogField(ctx, "stage", "evict oversubscribed"), NewOversubscribedEvictor( sch.jobRepo, sch.schedulingContext.PriorityClasses, @@ -226,10 +212,7 @@ func (sch *PreemptingQueueScheduler) Schedule(ctx context.Context) (*SchedulerRe // Since no new jobs are considered in this round, the scheduling key check brings no benefit. sch.SkipUnsuccessfulSchedulingKeyCheck() schedulerResult, err = sch.schedule( - ctxlogrus.ToContext( - ctx, - log.WithField("stage", "schedule after oversubscribed eviction"), - ), + armadacontext.WithLogField(ctx, "stage", "schedule after oversubscribed eviction"), inMemoryJobRepo, // Only evicted jobs should be scheduled in this round, // so we provide an empty repo for queued jobs. 
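armadacontext.WithLogField replaces the ctxlogrus.ToContext / log.WithField round trip used before: rather than stashing a logger in a plain context and extracting it again, the scheduler derives a child context whose Log already carries the extra field. What that helper plausibly does, assuming the Context shape sketched earlier in this section:

// WithLogField is a sketch: it assumes Context exposes an embedded
// context.Context plus a logrus entry, and only adds one field to the logger.
func WithLogField(parent *Context, key string, value interface{}) *Context {
	return &Context{Context: parent.Context, Log: parent.Log.WithField(key, value)}
}

Call sites then read as in the hunk above, e.g. armadacontext.WithLogField(ctx, "stage", "evict for resource balancing").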
@@ -258,10 +241,10 @@ func (sch *PreemptingQueueScheduler) Schedule(ctx context.Context) (*SchedulerRe return nil, err } if s := JobsSummary(preemptedJobs); s != "" { - log.Infof("preempting running jobs; %s", s) + ctx.Log.Infof("preempting running jobs; %s", s) } if s := JobsSummary(scheduledJobs); s != "" { - log.Infof("scheduling new jobs; %s", s) + ctx.Log.Infof("scheduling new jobs; %s", s) } if sch.enableAssertions { err := sch.assertions( @@ -282,7 +265,7 @@ func (sch *PreemptingQueueScheduler) Schedule(ctx context.Context) (*SchedulerRe }, nil } -func (sch *PreemptingQueueScheduler) evict(ctx context.Context, evictor *Evictor) (*EvictorResult, *InMemoryJobRepository, error) { +func (sch *PreemptingQueueScheduler) evict(ctx *armadacontext.Context, evictor *Evictor) (*EvictorResult, *InMemoryJobRepository, error) { if evictor == nil { return &EvictorResult{}, NewInMemoryJobRepository(sch.schedulingContext.PriorityClasses), nil } @@ -348,7 +331,7 @@ func (sch *PreemptingQueueScheduler) evict(ctx context.Context, evictor *Evictor // When evicting jobs, gangs may have been partially evicted. // Here, we evict all jobs in any gang for which at least one job was already evicted. -func (sch *PreemptingQueueScheduler) evictGangs(ctx context.Context, txn *memdb.Txn, previousEvictorResult *EvictorResult) (*EvictorResult, error) { +func (sch *PreemptingQueueScheduler) evictGangs(ctx *armadacontext.Context, txn *memdb.Txn, previousEvictorResult *EvictorResult) (*EvictorResult, error) { gangJobIds, gangNodeIds, err := sch.collectIdsForGangEviction(previousEvictorResult.EvictedJobsById) if err != nil { return nil, err @@ -512,7 +495,7 @@ func (q MinimalQueue) GetWeight() float64 { // addEvictedJobsToNodeDb adds evicted jobs to the NodeDb. // Needed to enable the nodeDb accounting for these when preempting. 
-func addEvictedJobsToNodeDb(ctx context.Context, sctx *schedulercontext.SchedulingContext, nodeDb *nodedb.NodeDb, inMemoryJobRepo *InMemoryJobRepository) error { +func addEvictedJobsToNodeDb(ctx *armadacontext.Context, sctx *schedulercontext.SchedulingContext, nodeDb *nodedb.NodeDb, inMemoryJobRepo *InMemoryJobRepository) error { gangItByQueue := make(map[string]*QueuedGangIterator) for _, qctx := range sctx.QueueSchedulingContexts { jobIt, err := inMemoryJobRepo.GetJobIterator(ctx, qctx.Queue) @@ -552,7 +535,7 @@ func addEvictedJobsToNodeDb(ctx context.Context, sctx *schedulercontext.Scheduli return nil } -func (sch *PreemptingQueueScheduler) schedule(ctx context.Context, inMemoryJobRepo *InMemoryJobRepository, jobRepo JobRepository) (*SchedulerResult, error) { +func (sch *PreemptingQueueScheduler) schedule(ctx *armadacontext.Context, inMemoryJobRepo *InMemoryJobRepository, jobRepo JobRepository) (*SchedulerResult, error) { jobIteratorByQueue := make(map[string]JobIterator) for _, qctx := range sch.schedulingContext.QueueSchedulingContexts { evictedIt, err := inMemoryJobRepo.GetJobIterator(ctx, qctx.Queue) @@ -717,9 +700,9 @@ func (sch *PreemptingQueueScheduler) assertions( type Evictor struct { jobRepo JobRepository priorityClasses map[string]types.PriorityClass - nodeFilter func(context.Context, *nodedb.Node) bool - jobFilter func(context.Context, interfaces.LegacySchedulerJob) bool - postEvictFunc func(context.Context, interfaces.LegacySchedulerJob, *nodedb.Node) + nodeFilter func(*armadacontext.Context, *nodedb.Node) bool + jobFilter func(*armadacontext.Context, interfaces.LegacySchedulerJob) bool + postEvictFunc func(*armadacontext.Context, interfaces.LegacySchedulerJob, *nodedb.Node) } type EvictorResult struct { @@ -735,7 +718,7 @@ func NewNodeEvictor( jobRepo JobRepository, priorityClasses map[string]types.PriorityClass, perNodeEvictionProbability float64, - jobFilter func(context.Context, interfaces.LegacySchedulerJob) bool, + jobFilter func(*armadacontext.Context, interfaces.LegacySchedulerJob) bool, random *rand.Rand, ) *Evictor { if perNodeEvictionProbability <= 0 { @@ -747,7 +730,7 @@ func NewNodeEvictor( return &Evictor{ jobRepo: jobRepo, priorityClasses: priorityClasses, - nodeFilter: func(_ context.Context, node *nodedb.Node) bool { + nodeFilter: func(_ *armadacontext.Context, node *nodedb.Node) bool { return len(node.AllocatedByJobId) > 0 && random.Float64() < perNodeEvictionProbability }, jobFilter: jobFilter, @@ -769,11 +752,11 @@ func NewFilteredEvictor( return &Evictor{ jobRepo: jobRepo, priorityClasses: priorityClasses, - nodeFilter: func(_ context.Context, node *nodedb.Node) bool { + nodeFilter: func(_ *armadacontext.Context, node *nodedb.Node) bool { shouldEvict := nodeIdsToEvict[node.Id] return shouldEvict }, - jobFilter: func(_ context.Context, job interfaces.LegacySchedulerJob) bool { + jobFilter: func(_ *armadacontext.Context, job interfaces.LegacySchedulerJob) bool { shouldEvict := jobIdsToEvict[job.GetId()] return shouldEvict }, @@ -804,7 +787,7 @@ func NewOversubscribedEvictor( return &Evictor{ jobRepo: jobRepo, priorityClasses: priorityClasses, - nodeFilter: func(_ context.Context, node *nodedb.Node) bool { + nodeFilter: func(_ *armadacontext.Context, node *nodedb.Node) bool { overSubscribedPriorities = make(map[int32]bool) for p, rl := range node.AllocatableByPriority { if p < 0 { @@ -820,10 +803,9 @@ func NewOversubscribedEvictor( } return len(overSubscribedPriorities) > 0 && random.Float64() < perNodeEvictionProbability }, - jobFilter: func(ctx 
context.Context, job interfaces.LegacySchedulerJob) bool { + jobFilter: func(ctx *armadacontext.Context, job interfaces.LegacySchedulerJob) bool { if job.GetAnnotations() == nil { - log := ctxlogrus.Extract(ctx) - log.Warnf("can't evict job %s: annotations not initialised", job.GetId()) + ctx.Log.Warnf("can't evict job %s: annotations not initialised", job.GetId()) return false } priorityClassName := job.GetPriorityClassName() @@ -844,7 +826,7 @@ func NewOversubscribedEvictor( // Any node for which nodeFilter returns false is skipped. // Any job for which jobFilter returns true is evicted (if the node was not skipped). // If a job was evicted from a node, postEvictFunc is called with the corresponding job and node. -func (evi *Evictor) Evict(ctx context.Context, it nodedb.NodeIterator) (*EvictorResult, error) { +func (evi *Evictor) Evict(ctx *armadacontext.Context, it nodedb.NodeIterator) (*EvictorResult, error) { var jobFilter func(job interfaces.LegacySchedulerJob) bool if evi.jobFilter != nil { jobFilter = func(job interfaces.LegacySchedulerJob) bool { return evi.jobFilter(ctx, job) } @@ -898,12 +880,11 @@ func (evi *Evictor) Evict(ctx context.Context, it nodedb.NodeIterator) (*Evictor // TODO: This is only necessary for jobs not scheduled in this cycle. // Since jobs scheduled in this cycle can be re-scheduled onto another node without triggering a preemption. -func defaultPostEvictFunc(ctx context.Context, job interfaces.LegacySchedulerJob, node *nodedb.Node) { +func defaultPostEvictFunc(ctx *armadacontext.Context, job interfaces.LegacySchedulerJob, node *nodedb.Node) { // Add annotation indicating to the scheduler this this job was evicted. annotations := job.GetAnnotations() if annotations == nil { - log := ctxlogrus.Extract(ctx) - log.Errorf("error evicting job %s: annotations not initialised", job.GetId()) + ctx.Log.Errorf("error evicting job %s: annotations not initialised", job.GetId()) } else { annotations[schedulerconfig.IsEvictedAnnotation] = "true" } @@ -911,8 +892,7 @@ func defaultPostEvictFunc(ctx context.Context, job interfaces.LegacySchedulerJob // Add node selector ensuring this job is only re-scheduled onto the node it was evicted from. 
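Callers that pass their own job filter to NewNodeEvictor now receive the wrapped context and can log through it directly, just as the built-in filters above do. A hedged sketch (the filter body and the 0.1 probability are illustrative only):

evictor := NewNodeEvictor(
	jobRepo,
	priorityClasses,
	0.1, // per-node eviction probability
	func(ctx *armadacontext.Context, job interfaces.LegacySchedulerJob) bool {
		if job.GetAnnotations() == nil {
			ctx.Log.Warnf("not evicting job %s: annotations not initialised", job.GetId())
			return false
		}
		return true
	},
	nil, // let the evictor fall back to its default random source
)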
nodeSelector := job.GetNodeSelector() if nodeSelector == nil { - log := ctxlogrus.Extract(ctx) - log.Errorf("error evicting job %s: nodeSelector not initialised", job.GetId()) + ctx.Log.Errorf("error evicting job %s: nodeSelector not initialised", job.GetId()) } else { nodeSelector[schedulerconfig.NodeIdLabel] = node.Id } diff --git a/internal/scheduler/preempting_queue_scheduler_test.go b/internal/scheduler/preempting_queue_scheduler_test.go index dc84cd225ae..84538cdccc2 100644 --- a/internal/scheduler/preempting_queue_scheduler_test.go +++ b/internal/scheduler/preempting_queue_scheduler_test.go @@ -1,13 +1,11 @@ package scheduler import ( - "context" "fmt" "math/rand" "testing" "time" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -17,6 +15,7 @@ import ( "k8s.io/apimachinery/pkg/api/resource" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" armadamaps "github.com/armadaproject/armada/internal/common/maps" armadaslices "github.com/armadaproject/armada/internal/common/slices" schedulerconstraints "github.com/armadaproject/armada/internal/scheduler/constraints" @@ -55,7 +54,7 @@ func TestEvictOversubscribed(t *testing.T) { nil, ) it := NewInMemoryNodeIterator([]*nodedb.Node{entry}) - result, err := evictor.Evict(context.Background(), it) + result, err := evictor.Evict(armadacontext.Background(), it) require.NoError(t, err) prioritiesByName := configuration.PriorityByPriorityClassName(testfixtures.TestPriorityClasses) @@ -1459,7 +1458,7 @@ func TestPreemptingQueueScheduler(t *testing.T) { if tc.SchedulingConfig.EnableNewPreemptionStrategy { sch.EnableNewPreemptionStrategy() } - result, err := sch.Schedule(ctxlogrus.ToContext(context.Background(), log)) + result, err := sch.Schedule(armadacontext.Background()) require.NoError(t, err) jobIdsByGangId = sch.jobIdsByGangId gangIdByJobId = sch.gangIdByJobId @@ -1734,7 +1733,7 @@ func BenchmarkPreemptingQueueScheduler(b *testing.B) { nil, nil, ) - result, err := sch.Schedule(context.Background()) + result, err := sch.Schedule(armadacontext.Background()) require.NoError(b, err) require.Equal(b, 0, len(result.PreemptedJobs)) @@ -1790,7 +1789,7 @@ func BenchmarkPreemptingQueueScheduler(b *testing.B) { nil, nil, ) - result, err := sch.Schedule(context.Background()) + result, err := sch.Schedule(armadacontext.Background()) require.NoError(b, err) // We expect the system to be in steady-state, i.e., no preempted/scheduled jobs. 
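For callers, the practical difference is that there is no separate logger to thread through: where this test previously attached a logrus entry via ctxlogrus.ToContext, it now just passes the wrapped context, and a scoped field can be added when wanted. Roughly (the WithLogField line is optional, not something the scheduler requires):

// before
result, err := sch.Schedule(ctxlogrus.ToContext(context.Background(), log))

// after
ctx := armadacontext.WithLogField(armadacontext.Background(), "service", "PreemptingQueueScheduler")
result, err := sch.Schedule(ctx)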
diff --git a/internal/scheduler/proxying_reports_server_test.go b/internal/scheduler/proxying_reports_server_test.go index 0dc81b54bf9..98f7c11fa97 100644 --- a/internal/scheduler/proxying_reports_server_test.go +++ b/internal/scheduler/proxying_reports_server_test.go @@ -1,13 +1,13 @@ package scheduler import ( - "context" "fmt" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" ) @@ -24,7 +24,7 @@ func TestProxyingSchedulingReportsServer_GetJobReports(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() sut, jobReportsClient := setupProxyingSchedulerReportsServerTest(t) @@ -62,7 +62,7 @@ func TestProxyingSchedulingReportsServer_GetSchedulingReport(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() sut, jobReportsClient := setupProxyingSchedulerReportsServerTest(t) @@ -100,7 +100,7 @@ func TestProxyingSchedulingReportsServer_GetQueueReport(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() sut, jobReportsClient := setupProxyingSchedulerReportsServerTest(t) diff --git a/internal/scheduler/publisher.go b/internal/scheduler/publisher.go index 0ae0595303b..0b308141961 100644 --- a/internal/scheduler/publisher.go +++ b/internal/scheduler/publisher.go @@ -1,7 +1,6 @@ package scheduler import ( - "context" "fmt" "strconv" "sync" @@ -13,6 +12,7 @@ import ( "github.com/pkg/errors" log "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/eventutil" "github.com/armadaproject/armada/internal/common/schedulers" "github.com/armadaproject/armada/pkg/armadaevents" @@ -28,12 +28,12 @@ const ( type Publisher interface { // PublishMessages will publish the supplied messages. A LeaderToken is provided and the // implementor may decide whether to publish based on the status of this token - PublishMessages(ctx context.Context, events []*armadaevents.EventSequence, shouldPublish func() bool) error + PublishMessages(ctx *armadacontext.Context, events []*armadaevents.EventSequence, shouldPublish func() bool) error // PublishMarkers publishes a single marker message for each Pulsar partition. Each marker // massage contains the supplied group id, which allows all marker messages for a given call // to be identified. The uint32 returned is the number of messages published - PublishMarkers(ctx context.Context, groupId uuid.UUID) (uint32, error) + PublishMarkers(ctx *armadacontext.Context, groupId uuid.UUID) (uint32, error) } // PulsarPublisher is the default implementation of Publisher @@ -77,7 +77,7 @@ func NewPulsarPublisher( // PublishMessages publishes all event sequences to pulsar. Event sequences for a given jobset will be combined into // single event sequences up to maxMessageBatchSize. 
-func (p *PulsarPublisher) PublishMessages(ctx context.Context, events []*armadaevents.EventSequence, shouldPublish func() bool) error { +func (p *PulsarPublisher) PublishMessages(ctx *armadacontext.Context, events []*armadaevents.EventSequence, shouldPublish func() bool) error { sequences := eventutil.CompactEventSequences(events) sequences, err := eventutil.LimitSequencesByteSize(sequences, p.maxMessageBatchSize, true) if err != nil { @@ -104,7 +104,7 @@ func (p *PulsarPublisher) PublishMessages(ctx context.Context, events []*armadae // Send messages if shouldPublish() { log.Debugf("Am leader so will publish") - sendCtx, cancel := context.WithTimeout(ctx, p.pulsarSendTimeout) + sendCtx, cancel := armadacontext.WithTimeout(ctx, p.pulsarSendTimeout) errored := false for _, msg := range msgs { p.producer.SendAsync(sendCtx, msg, func(_ pulsar.MessageID, _ *pulsar.ProducerMessage, err error) { @@ -128,7 +128,7 @@ func (p *PulsarPublisher) PublishMessages(ctx context.Context, events []*armadae // PublishMarkers sends one pulsar message (containing an armadaevents.PartitionMarker) to each partition // of the producer's Pulsar topic. -func (p *PulsarPublisher) PublishMarkers(ctx context.Context, groupId uuid.UUID) (uint32, error) { +func (p *PulsarPublisher) PublishMarkers(ctx *armadacontext.Context, groupId uuid.UUID) (uint32, error) { for i := 0; i < p.numPartitions; i++ { pm := &armadaevents.PartitionMarker{ GroupId: armadaevents.ProtoUuidFromUuid(groupId), diff --git a/internal/scheduler/publisher_test.go b/internal/scheduler/publisher_test.go index a524f9e26b9..6ecb200d416 100644 --- a/internal/scheduler/publisher_test.go +++ b/internal/scheduler/publisher_test.go @@ -1,7 +1,6 @@ package scheduler import ( - "context" "fmt" "math" "testing" @@ -15,6 +14,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/mocks" "github.com/armadaproject/armada/internal/common/pulsarutils" "github.com/armadaproject/armada/pkg/armadaevents" @@ -89,7 +89,7 @@ func TestPulsarPublisher_TestPublish(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() ctrl := gomock.NewController(t) mockPulsarClient := mocks.NewMockClient(ctrl) @@ -106,7 +106,7 @@ func TestPulsarPublisher_TestPublish(t *testing.T) { mockPulsarProducer. EXPECT(). SendAsync(gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, msg *pulsar.ProducerMessage, callback func(pulsar.MessageID, *pulsar.ProducerMessage, error)) { + DoAndReturn(func(_ *armadacontext.Context, msg *pulsar.ProducerMessage, callback func(pulsar.MessageID, *pulsar.ProducerMessage, error)) { es := &armadaevents.EventSequence{} err := proto.Unmarshal(msg.Payload, es) require.NoError(t, err) @@ -177,7 +177,7 @@ func TestPulsarPublisher_TestPublishMarkers(t *testing.T) { mockPulsarProducer. EXPECT(). Send(gomock.Any(), gomock.Any()). 
- DoAndReturn(func(_ context.Context, msg *pulsar.ProducerMessage) (pulsar.MessageID, error) { + DoAndReturn(func(_ *armadacontext.Context, msg *pulsar.ProducerMessage) (pulsar.MessageID, error) { numPublished++ key, ok := msg.Properties[explicitPartitionKey] if ok { @@ -190,7 +190,7 @@ func TestPulsarPublisher_TestPublishMarkers(t *testing.T) { }).AnyTimes() options := pulsar.ProducerOptions{Topic: topic} - ctx := context.TODO() + ctx := armadacontext.TODO() publisher, err := NewPulsarPublisher(mockPulsarClient, options, 5*time.Second) require.NoError(t, err) diff --git a/internal/scheduler/queue_scheduler.go b/internal/scheduler/queue_scheduler.go index 825c9f26bfb..cf03c7af3fc 100644 --- a/internal/scheduler/queue_scheduler.go +++ b/internal/scheduler/queue_scheduler.go @@ -2,13 +2,13 @@ package scheduler import ( "container/heap" - "context" "reflect" "time" "github.com/pkg/errors" "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/logging" schedulerconstraints "github.com/armadaproject/armada/internal/scheduler/constraints" schedulercontext "github.com/armadaproject/armada/internal/scheduler/context" @@ -60,7 +60,7 @@ func (sch *QueueScheduler) SkipUnsuccessfulSchedulingKeyCheck() { sch.gangScheduler.SkipUnsuccessfulSchedulingKeyCheck() } -func (sch *QueueScheduler) Schedule(ctx context.Context) (*SchedulerResult, error) { +func (sch *QueueScheduler) Schedule(ctx *armadacontext.Context) (*SchedulerResult, error) { nodeIdByJobId := make(map[string]string) scheduledJobs := make([]interfaces.LegacySchedulerJob, 0) for { diff --git a/internal/scheduler/queue_scheduler_test.go b/internal/scheduler/queue_scheduler_test.go index cbbc537e495..3832db7ceba 100644 --- a/internal/scheduler/queue_scheduler_test.go +++ b/internal/scheduler/queue_scheduler_test.go @@ -1,7 +1,6 @@ package scheduler import ( - "context" "fmt" "testing" @@ -13,6 +12,7 @@ import ( "k8s.io/apimachinery/pkg/api/resource" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" armadaslices "github.com/armadaproject/armada/internal/common/slices" "github.com/armadaproject/armada/internal/common/util" schedulerconstraints "github.com/armadaproject/armada/internal/scheduler/constraints" @@ -512,14 +512,14 @@ func TestQueueScheduler(t *testing.T) { ) jobIteratorByQueue := make(map[string]JobIterator) for queue := range tc.PriorityFactorByQueue { - it, err := jobRepo.GetJobIterator(context.Background(), queue) + it, err := jobRepo.GetJobIterator(armadacontext.Background(), queue) require.NoError(t, err) jobIteratorByQueue[queue] = it } sch, err := NewQueueScheduler(sctx, constraints, nodeDb, jobIteratorByQueue) require.NoError(t, err) - result, err := sch.Schedule(context.Background()) + result, err := sch.Schedule(armadacontext.Background()) require.NoError(t, err) // Check that the right jobs got scheduled. 
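The Publisher interface earlier in this change (publisher.go) moves to the wrapped context in the same way, and a test double against the new signatures stays tiny. A sketch; fakePublisher and its counter are hypothetical:

// fakePublisher satisfies the Publisher interface from publisher.go for tests.
type fakePublisher struct {
	sequences int
}

func (f *fakePublisher) PublishMessages(ctx *armadacontext.Context, events []*armadaevents.EventSequence, shouldPublish func() bool) error {
	if shouldPublish() {
		f.sequences += len(events)
	}
	return nil
}

func (f *fakePublisher) PublishMarkers(ctx *armadacontext.Context, groupId uuid.UUID) (uint32, error) {
	// Pretend a single marker was published; the real implementation sends one per partition.
	return 1, nil
}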
diff --git a/internal/scheduler/reports_test.go b/internal/scheduler/reports_test.go index b3c8f568d38..fcc0837188a 100644 --- a/internal/scheduler/reports_test.go +++ b/internal/scheduler/reports_test.go @@ -1,7 +1,6 @@ package scheduler import ( - "context" "fmt" "testing" "time" @@ -10,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "k8s.io/apimachinery/pkg/api/resource" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/util" schedulercontext "github.com/armadaproject/armada/internal/scheduler/context" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" @@ -159,7 +159,7 @@ func TestAddGetSchedulingContext(t *testing.T) { func TestTestAddGetSchedulingContextConcurrency(t *testing.T) { repo, err := NewSchedulingContextRepository(10) require.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), time.Second) defer cancel() for _, executorId := range []string{"foo", "bar"} { go func(executorId string) { @@ -202,7 +202,7 @@ func TestReportDoesNotExist(t *testing.T) { require.NoError(t, err) err = repo.AddSchedulingContext(testSchedulingContext("executor-01")) require.NoError(t, err) - ctx := context.Background() + ctx := armadacontext.Background() queue := "queue-does-not-exist" jobId := util.NewULID() diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go index 88c43b14a0f..ccc4d998ff5 100644 --- a/internal/scheduler/scheduler.go +++ b/internal/scheduler/scheduler.go @@ -1,19 +1,18 @@ package scheduler import ( - "context" "fmt" "time" "github.com/gogo/protobuf/proto" "github.com/google/uuid" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/pkg/errors" "golang.org/x/exp/maps" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/util/clock" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/logging" "github.com/armadaproject/armada/internal/common/stringinterner" "github.com/armadaproject/armada/internal/scheduler/database" @@ -116,40 +115,37 @@ func NewScheduler( } // Run enters the scheduling loop, which will continue until ctx is cancelled. -func (s *Scheduler) Run(ctx context.Context) error { - log := ctxlogrus.Extract(ctx) - log = log.WithField("service", "scheduler") - ctx = ctxlogrus.ToContext(ctx, log) - log.Infof("starting scheduler with cycle time %s", s.cyclePeriod) - defer log.Info("scheduler stopped") +func (s *Scheduler) Run(ctx *armadacontext.Context) error { + ctx.Log.Infof("starting scheduler with cycle time %s", s.cyclePeriod) + defer ctx.Log.Info("scheduler stopped") // JobDb initialisation. 
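Scheduler.Run no longer builds its own field-scoped logger via ctxlogrus; whichever process starts the scheduler now decides what the context's log entry carries. A sketch of that wiring, using only helpers visible in this diff (s and leaderController are assumed to be constructed elsewhere):

ctx := armadacontext.WithLogField(armadacontext.Background(), "service", "scheduler")
g, ctx := armadacontext.ErrGroup(ctx)
g.Go(func() error { return s.Run(ctx) })
g.Go(func() error { return leaderController.Run(ctx) })
if err := g.Wait(); err != nil {
	ctx.Log.WithError(err).Error("scheduler exited")
}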
start := s.clock.Now() if err := s.initialise(ctx); err != nil { return err } - log.Infof("JobDb initialised in %s", s.clock.Since(start)) + ctx.Log.Infof("JobDb initialised in %s", s.clock.Since(start)) ticker := s.clock.NewTicker(s.cyclePeriod) prevLeaderToken := InvalidLeaderToken() for { select { case <-ctx.Done(): - log.Infof("context cancelled; returning.") + ctx.Log.Infof("context cancelled; returning.") return ctx.Err() case <-ticker.C(): start := s.clock.Now() leaderToken := s.leaderController.GetToken() fullUpdate := false - log.Infof("received leaderToken; leader status is %t", leaderToken.leader) + ctx.Log.Infof("received leaderToken; leader status is %t", leaderToken.leader) // If we are becoming leader then we must ensure we have caught up to all Pulsar messages if leaderToken.leader && leaderToken != prevLeaderToken { - log.Infof("becoming leader") - syncContext, cancel := context.WithTimeout(ctx, 5*time.Minute) + ctx.Log.Infof("becoming leader") + syncContext, cancel := armadacontext.WithTimeout(ctx, 5*time.Minute) err := s.ensureDbUpToDate(syncContext, 1*time.Second) if err != nil { - log.WithError(err).Error("could not become leader") + logging.WithStacktrace(ctx.Log, err).Error("could not become leader") leaderToken = InvalidLeaderToken() } else { fullUpdate = true @@ -169,7 +165,7 @@ func (s *Scheduler) Run(ctx context.Context) error { result, err := s.cycle(ctx, fullUpdate, leaderToken, shouldSchedule) if err != nil { - logging.WithStacktrace(log, err).Error("scheduling cycle failure") + logging.WithStacktrace(ctx.Log, err).Error("scheduling cycle failure") leaderToken = InvalidLeaderToken() } @@ -181,10 +177,10 @@ func (s *Scheduler) Run(ctx context.Context) error { // Only the leader token does real scheduling rounds. s.metrics.ReportScheduleCycleTime(cycleTime) s.metrics.ReportSchedulerResult(result) - log.Infof("scheduling cycle completed in %s", cycleTime) + ctx.Log.Infof("scheduling cycle completed in %s", cycleTime) } else { s.metrics.ReportReconcileCycleTime(cycleTime) - log.Infof("reconciliation cycle completed in %s", cycleTime) + ctx.Log.Infof("reconciliation cycle completed in %s", cycleTime) } prevLeaderToken = leaderToken @@ -198,11 +194,9 @@ func (s *Scheduler) Run(ctx context.Context) error { // cycle is a single iteration of the main scheduling loop. // If updateAll is true, we generate events from all jobs in the jobDb. // Otherwise, we only generate events from jobs updated since the last cycle. -func (s *Scheduler) cycle(ctx context.Context, updateAll bool, leaderToken LeaderToken, shouldSchedule bool) (overallSchedulerResult SchedulerResult, err error) { +func (s *Scheduler) cycle(ctx *armadacontext.Context, updateAll bool, leaderToken LeaderToken, shouldSchedule bool) (overallSchedulerResult SchedulerResult, err error) { overallSchedulerResult = SchedulerResult{EmptyResult: true} - log := ctxlogrus.Extract(ctx) - log = log.WithField("function", "cycle") // Update job state. 
updatedJobs, err := s.syncState(ctx) if err != nil { @@ -244,7 +238,7 @@ func (s *Scheduler) cycle(ctx context.Context, updateAll bool, leaderToken Leade } var resultEvents []*armadaevents.EventSequence - resultEvents, err = s.eventsFromSchedulerResult(txn, result) + resultEvents, err = s.eventsFromSchedulerResult(result) if err != nil { return } @@ -262,22 +256,19 @@ func (s *Scheduler) cycle(ctx context.Context, updateAll bool, leaderToken Leade if err = s.publisher.PublishMessages(ctx, events, isLeader); err != nil { return } - log.Infof("published %d events to pulsar in %s", len(events), s.clock.Since(start)) + ctx.Log.Infof("published %d events to pulsar in %s", len(events), s.clock.Since(start)) txn.Commit() return } // syncState updates jobs in jobDb to match state in postgres and returns all updated jobs. -func (s *Scheduler) syncState(ctx context.Context) ([]*jobdb.Job, error) { - log := ctxlogrus.Extract(ctx) - log = log.WithField("function", "syncState") - +func (s *Scheduler) syncState(ctx *armadacontext.Context) ([]*jobdb.Job, error) { start := s.clock.Now() updatedJobs, updatedRuns, err := s.jobRepository.FetchJobUpdates(ctx, s.jobsSerial, s.runsSerial) if err != nil { return nil, err } - log.Infof("received %d updated jobs and %d updated job runs in %s", len(updatedJobs), len(updatedRuns), s.clock.Since(start)) + ctx.Log.Infof("received %d updated jobs and %d updated job runs in %s", len(updatedJobs), len(updatedRuns), s.clock.Since(start)) txn := s.jobDb.WriteTxn() defer txn.Abort() @@ -321,7 +312,7 @@ func (s *Scheduler) syncState(ctx context.Context) ([]*jobdb.Job, error) { // If the job is nil or terminal at this point then it cannot be active. // In this case we can ignore the run. if job == nil || job.InTerminalState() { - log.Debugf("job %s is not active; ignoring update for run %s", jobId, dbRun.RunID) + ctx.Log.Debugf("job %s is not active; ignoring update for run %s", jobId, dbRun.RunID) continue } } @@ -391,7 +382,7 @@ func (s *Scheduler) addNodeAntiAffinitiesForAttemptedRunsIfSchedulable(job *jobd } // eventsFromSchedulerResult generates necessary EventSequences from the provided SchedulerResult. -func (s *Scheduler) eventsFromSchedulerResult(txn *jobdb.Txn, result *SchedulerResult) ([]*armadaevents.EventSequence, error) { +func (s *Scheduler) eventsFromSchedulerResult(result *SchedulerResult) ([]*armadaevents.EventSequence, error) { return EventsFromSchedulerResult(result, s.clock.Now()) } @@ -507,7 +498,7 @@ func AppendEventSequencesFromScheduledJobs(eventSequences []*armadaevents.EventS // generateUpdateMessages generates EventSequences representing the state changes on updated jobs // If there are no state changes then an empty slice will be returned -func (s *Scheduler) generateUpdateMessages(ctx context.Context, updatedJobs []*jobdb.Job, txn *jobdb.Txn) ([]*armadaevents.EventSequence, error) { +func (s *Scheduler) generateUpdateMessages(ctx *armadacontext.Context, updatedJobs []*jobdb.Job, txn *jobdb.Txn) ([]*armadaevents.EventSequence, error) { failedRunIds := make([]uuid.UUID, 0, len(updatedJobs)) for _, job := range updatedJobs { run := job.LatestRun() @@ -708,10 +699,7 @@ func (s *Scheduler) generateUpdateMessagesFromJob(job *jobdb.Job, jobRunErrors m // expireJobsIfNecessary removes any jobs from the JobDb which are running on stale executors. 
// It also generates an EventSequence for each job, indicating that both the run and the job has failed // Note that this is different behaviour from the old scheduler which would allow expired jobs to be rerun -func (s *Scheduler) expireJobsIfNecessary(ctx context.Context, txn *jobdb.Txn) ([]*armadaevents.EventSequence, error) { - log := ctxlogrus.Extract(ctx) - log = log.WithField("function", "expireJobsIfNecessary") - +func (s *Scheduler) expireJobsIfNecessary(ctx *armadacontext.Context, txn *jobdb.Txn) ([]*armadaevents.EventSequence, error) { heartbeatTimes, err := s.executorRepository.GetLastUpdateTimes(ctx) if err != nil { return nil, err @@ -726,14 +714,14 @@ func (s *Scheduler) expireJobsIfNecessary(ctx context.Context, txn *jobdb.Txn) ( // has been completely removed for executor, heartbeat := range heartbeatTimes { if heartbeat.Before(cutOff) { - log.Warnf("Executor %s has not reported a hearbeart since %v. Will expire all jobs running on this executor", executor, heartbeat) + ctx.Log.Warnf("Executor %s has not reported a heartbeat since %v. Will expire all jobs running on this executor", executor, heartbeat) staleExecutors[executor] = true } } // All clusters have had a heartbeat recently. No need to expire any jobs if len(staleExecutors) == 0 { - log.Infof("No stale executors found. No jobs need to be expired") + ctx.Log.Infof("No stale executors found. No jobs need to be expired") return nil, nil } @@ -750,7 +738,7 @@ func (s *Scheduler) expireJobsIfNecessary(ctx context.Context, txn *jobdb.Txn) ( run := job.LatestRun() if run != nil && !job.Queued() && staleExecutors[run.Executor()] { - log.Warnf("Cancelling job %s as it is running on lost executor %s", job.Id(), run.Executor()) + ctx.Log.Warnf("Cancelling job %s as it is running on lost executor %s", job.Id(), run.Executor()) jobsToUpdate = append(jobsToUpdate, job.WithQueued(false).WithFailed(true).WithUpdatedRun(run.WithFailed(true))) jobId, err := armadaevents.ProtoUuidFromUlidString(job.Id()) @@ -808,16 +796,14 @@ func (s *Scheduler) now() *time.Time { // initialise builds the initial job db based on the current database state // right now this is quite dim and loads the entire database but in the future // we should be able to make it load active jobs/runs only -func (s *Scheduler) initialise(ctx context.Context) error { - log := ctxlogrus.Extract(ctx) - log = log.WithField("function", "initialise") +func (s *Scheduler) initialise(ctx *armadacontext.Context) error { for { select { case <-ctx.Done(): return nil default: if _, err := s.syncState(ctx); err != nil { - log.WithError(err).Error("failed to initialise; trying again in 1 second") + ctx.Log.WithError(err).Error("failed to initialise; trying again in 1 second") time.Sleep(1 * time.Second) } else { // Initialisation succeeded. @@ -830,10 +816,7 @@ func (s *Scheduler) initialise(ctx context.Context) error { // ensureDbUpToDate blocks until that the database state contains all Pulsar messages sent *before* this // function was called. This is achieved firstly by publishing messages to Pulsar and then polling the // database until all messages have been written. 
-func (s *Scheduler) ensureDbUpToDate(ctx context.Context, pollInterval time.Duration) error { - log := ctxlogrus.Extract(ctx) - log = log.WithField("function", "ensureDbUpToDate") - +func (s *Scheduler) ensureDbUpToDate(ctx *armadacontext.Context, pollInterval time.Duration) error { groupId := uuid.New() var numSent uint32 var err error @@ -847,7 +830,7 @@ func (s *Scheduler) ensureDbUpToDate(ctx context.Context, pollInterval time.Dura default: numSent, err = s.publisher.PublishMarkers(ctx, groupId) if err != nil { - log.WithError(err).Error("Error sending marker messages to pulsar") + ctx.Log.WithError(err).Error("Error sending marker messages to pulsar") s.clock.Sleep(pollInterval) } else { messagesSent = true @@ -863,13 +846,13 @@ func (s *Scheduler) ensureDbUpToDate(ctx context.Context, pollInterval time.Dura default: numReceived, err := s.jobRepository.CountReceivedPartitions(ctx, groupId) if err != nil { - log.WithError(err).Error("Error querying the database or marker messages") + ctx.Log.WithError(err).Error("Error querying the database or marker messages") } if numSent == numReceived { - log.Infof("Successfully ensured that database state is up to date") + ctx.Log.Infof("Successfully ensured that database state is up to date") return nil } - log.Infof("Recevied %d partitions, still waiting on %d", numReceived, numSent-numReceived) + ctx.Log.Infof("Received %d partitions, still waiting on %d", numReceived, numSent-numReceived) s.clock.Sleep(pollInterval) } } diff --git a/internal/scheduler/scheduler_test.go b/internal/scheduler/scheduler_test.go index 1db7ad4ae8f..584f4552d42 100644 --- a/internal/scheduler/scheduler_test.go +++ b/internal/scheduler/scheduler_test.go @@ -1,7 +1,6 @@ package scheduler import ( - "context" "fmt" "sync" "testing" @@ -15,6 +14,7 @@ import ( "k8s.io/apimachinery/pkg/util/clock" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" protoutil "github.com/armadaproject/armada/internal/common/proto" "github.com/armadaproject/armada/internal/common/stringinterner" "github.com/armadaproject/armada/internal/common/util" @@ -527,7 +527,7 @@ func TestScheduler_TestCycle(t *testing.T) { txn.Commit() // run a scheduler cycle - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) _, err = sched.cycle(ctx, false, sched.leaderController.GetToken(), true) if tc.fetchError || tc.publishError || tc.scheduleError { assert.Error(t, err) @@ -684,7 +684,7 @@ func TestRun(t *testing.T) { sched.clock = testClock - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := armadacontext.WithCancel(armadacontext.Background()) //nolint:errcheck go sched.Run(ctx) @@ -861,7 +861,7 @@ func TestScheduler_TestSyncState(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() // Test objects @@ -943,31 +943,31 @@ type testJobRepository struct { numReceivedPartitions uint32 } -func (t *testJobRepository) FindInactiveRuns(ctx context.Context, runIds []uuid.UUID) ([]uuid.UUID, error) { + func (t *testJobRepository) FindInactiveRuns(ctx *armadacontext.Context, runIds []uuid.UUID) ([]uuid.UUID, error) { // TODO implement me panic("implement me") } -func (t *testJobRepository) FetchJobRunLeases(ctx 
context.Context, executor string, maxResults uint, excludedRunIds []uuid.UUID) ([]*database.JobRunLease, error) { +func (t *testJobRepository) FetchJobRunLeases(ctx *armadacontext.Context, executor string, maxResults uint, excludedRunIds []uuid.UUID) ([]*database.JobRunLease, error) { // TODO implement me panic("implement me") } -func (t *testJobRepository) FetchJobUpdates(ctx context.Context, jobSerial int64, jobRunSerial int64) ([]database.Job, []database.Run, error) { +func (t *testJobRepository) FetchJobUpdates(ctx *armadacontext.Context, jobSerial int64, jobRunSerial int64) ([]database.Job, []database.Run, error) { if t.shouldError { return nil, nil, errors.New("error fetchiung job updates") } return t.updatedJobs, t.updatedRuns, nil } -func (t *testJobRepository) FetchJobRunErrors(ctx context.Context, runIds []uuid.UUID) (map[uuid.UUID]*armadaevents.Error, error) { +func (t *testJobRepository) FetchJobRunErrors(ctx *armadacontext.Context, runIds []uuid.UUID) (map[uuid.UUID]*armadaevents.Error, error) { if t.shouldError { return nil, errors.New("error fetching job run errors") } return t.errors, nil } -func (t *testJobRepository) CountReceivedPartitions(ctx context.Context, groupId uuid.UUID) (uint32, error) { +func (t *testJobRepository) CountReceivedPartitions(ctx *armadacontext.Context, groupId uuid.UUID) (uint32, error) { if t.shouldError { return 0, errors.New("error counting received partitions") } @@ -979,18 +979,18 @@ type testExecutorRepository struct { shouldError bool } -func (t testExecutorRepository) GetExecutors(ctx context.Context) ([]*schedulerobjects.Executor, error) { +func (t testExecutorRepository) GetExecutors(ctx *armadacontext.Context) ([]*schedulerobjects.Executor, error) { panic("not implemented") } -func (t testExecutorRepository) GetLastUpdateTimes(ctx context.Context) (map[string]time.Time, error) { +func (t testExecutorRepository) GetLastUpdateTimes(ctx *armadacontext.Context) (map[string]time.Time, error) { if t.shouldError { return nil, errors.New("error getting last update time") } return t.updateTimes, nil } -func (t testExecutorRepository) StoreExecutor(ctx context.Context, executor *schedulerobjects.Executor) error { +func (t testExecutorRepository) StoreExecutor(ctx *armadacontext.Context, executor *schedulerobjects.Executor) error { panic("not implemented") } @@ -1001,7 +1001,7 @@ type testSchedulingAlgo struct { shouldError bool } -func (t *testSchedulingAlgo) Schedule(ctx context.Context, txn *jobdb.Txn, jobDb *jobdb.JobDb) (*SchedulerResult, error) { +func (t *testSchedulingAlgo) Schedule(ctx *armadacontext.Context, txn *jobdb.Txn, jobDb *jobdb.JobDb) (*SchedulerResult, error) { t.numberOfScheduleCalls++ if t.shouldError { return nil, errors.New("error scheduling jobs") @@ -1049,7 +1049,7 @@ type testPublisher struct { shouldError bool } -func (t *testPublisher) PublishMessages(ctx context.Context, events []*armadaevents.EventSequence, _ func() bool) error { +func (t *testPublisher) PublishMessages(ctx *armadacontext.Context, events []*armadaevents.EventSequence, _ func() bool) error { t.events = events if t.shouldError { return errors.New("Error when publishing") @@ -1061,7 +1061,7 @@ func (t *testPublisher) Reset() { t.events = nil } -func (t *testPublisher) PublishMarkers(ctx context.Context, groupId uuid.UUID) (uint32, error) { +func (t *testPublisher) PublishMarkers(ctx *armadacontext.Context, groupId uuid.UUID) (uint32, error) { return 100, nil } diff --git a/internal/scheduler/schedulerapp.go b/internal/scheduler/schedulerapp.go index 
ef742c3dc24..9ba1302c920 100644 --- a/internal/scheduler/schedulerapp.go +++ b/internal/scheduler/schedulerapp.go @@ -10,17 +10,16 @@ import ( "github.com/apache/pulsar-client-go/pulsar" "github.com/go-redis/redis" "github.com/google/uuid" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" - "golang.org/x/sync/errgroup" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" "k8s.io/client-go/tools/clientcmd" "github.com/armadaproject/armada/internal/common" "github.com/armadaproject/armada/internal/common/app" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/auth" dbcommon "github.com/armadaproject/armada/internal/common/database" grpcCommon "github.com/armadaproject/armada/internal/common/grpc" @@ -35,9 +34,7 @@ import ( // Run sets up a Scheduler application and runs it until a SIGTERM is received func Run(config schedulerconfig.Configuration) error { - g, ctx := errgroup.WithContext(app.CreateContextWithShutdown()) - logrusLogger := log.NewEntry(log.StandardLogger()) - ctx = ctxlogrus.ToContext(ctx, logrusLogger) + g, ctx := armadacontext.ErrGroup(app.CreateContextWithShutdown()) ////////////////////////////////////////////////////////////////////////// // Health Checks diff --git a/internal/scheduler/scheduling_algo.go b/internal/scheduler/scheduling_algo.go index 11f6b667f96..a1865d1601b 100644 --- a/internal/scheduler/scheduling_algo.go +++ b/internal/scheduler/scheduling_algo.go @@ -7,7 +7,6 @@ import ( "github.com/benbjohnson/immutable" "github.com/google/uuid" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/pkg/errors" "github.com/sirupsen/logrus" "golang.org/x/exp/maps" @@ -16,6 +15,7 @@ import ( "k8s.io/apimachinery/pkg/util/clock" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/logging" armadaslices "github.com/armadaproject/armada/internal/common/slices" "github.com/armadaproject/armada/internal/common/util" @@ -34,7 +34,7 @@ import ( type SchedulingAlgo interface { // Schedule should assign jobs to nodes. // Any jobs that are scheduled should be marked as such in the JobDb using the transaction provided. - Schedule(ctx context.Context, txn *jobdb.Txn, jobDb *jobdb.JobDb) (*SchedulerResult, error) + Schedule(ctx *armadacontext.Context, txn *jobdb.Txn, jobDb *jobdb.JobDb) (*SchedulerResult, error) } // FairSchedulingAlgo is a SchedulingAlgo based on PreemptingQueueScheduler. @@ -88,12 +88,10 @@ func NewFairSchedulingAlgo( // It maintains state of which executors it has considered already and may take multiple Schedule() calls to consider all executors if scheduling is slow. // Newly leased jobs are updated as such in the jobDb using the transaction provided and are also returned to the caller. func (l *FairSchedulingAlgo) Schedule( - ctx context.Context, + ctx *armadacontext.Context, txn *jobdb.Txn, jobDb *jobdb.JobDb, ) (*SchedulerResult, error) { - log := ctxlogrus.Extract(ctx) - overallSchedulerResult := &SchedulerResult{ NodeIdByJobId: make(map[string]string), SchedulingContexts: make([]*schedulercontext.SchedulingContext, 0, 0), @@ -101,11 +99,11 @@ func (l *FairSchedulingAlgo) Schedule( // Exit immediately if scheduling is disabled. 
if l.schedulingConfig.DisableScheduling { - log.Info("skipping scheduling - scheduling disabled") + ctx.Log.Info("skipping scheduling - scheduling disabled") return overallSchedulerResult, nil } - ctxWithTimeout, cancel := context.WithTimeout(ctx, l.maxSchedulingDuration) + ctxWithTimeout, cancel := armadacontext.WithTimeout(ctx, l.maxSchedulingDuration) defer cancel() fsctx, err := l.newFairSchedulingAlgoContext(ctx, txn, jobDb) @@ -123,7 +121,7 @@ func (l *FairSchedulingAlgo) Schedule( select { case <-ctxWithTimeout.Done(): // We've reached the scheduling time limit; exit gracefully. - log.Info("ending scheduling round early as we have hit the maximum scheduling duration") + ctx.Log.Info("ending scheduling round early as we have hit the maximum scheduling duration") return overallSchedulerResult, nil default: } @@ -142,7 +140,7 @@ func (l *FairSchedulingAlgo) Schedule( // Assume pool and minimumJobSize are consistent within the group. pool := executorGroup[0].Pool minimumJobSize := executorGroup[0].MinimumJobSize - log.Infof( + ctx.Log.Infof( "scheduling on executor group %s with capacity %s", executorGroupLabel, fsctx.totalCapacityByPool[pool].CompactString(), ) @@ -158,14 +156,14 @@ func (l *FairSchedulingAlgo) Schedule( // add the executorGroupLabel back to l.executorGroupsToSchedule such that we try it again next time, // and exit gracefully. l.executorGroupsToSchedule = append(l.executorGroupsToSchedule, executorGroupLabel) - log.Info("stopped scheduling early as we have hit the maximum scheduling duration") + ctx.Log.Info("stopped scheduling early as we have hit the maximum scheduling duration") break } else if err != nil { return nil, err } if l.schedulingContextRepository != nil { if err := l.schedulingContextRepository.AddSchedulingContext(sctx); err != nil { - logging.WithStacktrace(log, err).Error("failed to add scheduling context") + logging.WithStacktrace(ctx.Log, err).Error("failed to add scheduling context") } } @@ -239,7 +237,7 @@ type fairSchedulingAlgoContext struct { jobDb *jobdb.JobDb } -func (l *FairSchedulingAlgo) newFairSchedulingAlgoContext(ctx context.Context, txn *jobdb.Txn, jobDb *jobdb.JobDb) (*fairSchedulingAlgoContext, error) { +func (l *FairSchedulingAlgo) newFairSchedulingAlgoContext(ctx *armadacontext.Context, txn *jobdb.Txn, jobDb *jobdb.JobDb) (*fairSchedulingAlgoContext, error) { executors, err := l.executorRepository.GetExecutors(ctx) if err != nil { return nil, err @@ -330,7 +328,7 @@ func (l *FairSchedulingAlgo) newFairSchedulingAlgoContext(ctx context.Context, t // scheduleOnExecutors schedules jobs on a specified set of executors. func (l *FairSchedulingAlgo) scheduleOnExecutors( - ctx context.Context, + ctx *armadacontext.Context, fsctx *fairSchedulingAlgoContext, pool string, minimumJobSize schedulerobjects.ResourceList, @@ -556,17 +554,16 @@ func (l *FairSchedulingAlgo) filterStaleExecutors(executors []*schedulerobjects. // // TODO: Let's also check that jobs are on the right nodes. 
func (l *FairSchedulingAlgo) filterLaggingExecutors( - ctx context.Context, + ctx *armadacontext.Context, executors []*schedulerobjects.Executor, leasedJobsByExecutor map[string][]*jobdb.Job, ) []*schedulerobjects.Executor { - log := ctxlogrus.Extract(ctx) activeExecutors := make([]*schedulerobjects.Executor, 0, len(executors)) for _, executor := range executors { leasedJobs := leasedJobsByExecutor[executor.Id] executorRuns, err := executor.AllRuns() if err != nil { - logging.WithStacktrace(log, err).Errorf("failed to retrieve runs for executor %s; will not be considered for scheduling", executor.Id) + logging.WithStacktrace(ctx.Log, err).Errorf("failed to retrieve runs for executor %s; will not be considered for scheduling", executor.Id) continue } executorRunIds := make(map[uuid.UUID]bool, len(executorRuns)) @@ -585,7 +582,7 @@ func (l *FairSchedulingAlgo) filterLaggingExecutors( if numUnacknowledgedJobs <= l.schedulingConfig.MaxUnacknowledgedJobsPerExecutor { activeExecutors = append(activeExecutors, executor) } else { - log.Warnf( + ctx.Log.Warnf( "%d unacknowledged jobs on executor %s exceeds limit of %d; executor will not be considered for scheduling", numUnacknowledgedJobs, executor.Id, l.schedulingConfig.MaxUnacknowledgedJobsPerExecutor, ) diff --git a/internal/scheduler/scheduling_algo_test.go b/internal/scheduler/scheduling_algo_test.go index 6cb6a276f6a..2bf766ecd40 100644 --- a/internal/scheduler/scheduling_algo_test.go +++ b/internal/scheduler/scheduling_algo_test.go @@ -1,7 +1,6 @@ package scheduler import ( - "context" "fmt" "math" "testing" @@ -14,6 +13,7 @@ import ( "k8s.io/apimachinery/pkg/util/clock" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" armadaslices "github.com/armadaproject/armada/internal/common/slices" "github.com/armadaproject/armada/internal/scheduler/database" "github.com/armadaproject/armada/internal/scheduler/jobdb" @@ -330,9 +330,8 @@ func TestSchedule(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx := testfixtures.ContextWithDefaultLogger(context.Background()) timeout := 5 * time.Second - ctx, cancel := context.WithTimeout(ctx, timeout) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), timeout) defer cancel() ctrl := gomock.NewController(t) diff --git a/internal/scheduler/simulator/simulator.go b/internal/scheduler/simulator/simulator.go index 639f617c275..94fa8989b84 100644 --- a/internal/scheduler/simulator/simulator.go +++ b/internal/scheduler/simulator/simulator.go @@ -3,20 +3,17 @@ package simulator import ( "bytes" "container/heap" - "context" - fmt "fmt" + "fmt" "os" "path/filepath" "strings" "time" "github.com/caarlos0/log" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/mattn/go-zglob" "github.com/oklog/ulid" "github.com/pkg/errors" "github.com/renstrom/shortuuid" - "github.com/sirupsen/logrus" "github.com/spf13/viper" "golang.org/x/exp/maps" "golang.org/x/exp/slices" @@ -25,6 +22,7 @@ import ( "k8s.io/apimachinery/pkg/util/yaml" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" commonconfig "github.com/armadaproject/armada/internal/common/config" armadaslices "github.com/armadaproject/armada/internal/common/slices" "github.com/armadaproject/armada/internal/common/util" @@ -34,7 +32,7 @@ import ( "github.com/armadaproject/armada/internal/scheduler/fairness" 
"github.com/armadaproject/armada/internal/scheduler/jobdb" "github.com/armadaproject/armada/internal/scheduler/nodedb" - schedulerobjects "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" + "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" "github.com/armadaproject/armada/internal/scheduleringester" "github.com/armadaproject/armada/pkg/armadaevents" ) @@ -472,7 +470,7 @@ func (s *Simulator) handleScheduleEvent() error { if s.schedulingConfig.EnableNewPreemptionStrategy { sch.EnableNewPreemptionStrategy() } - ctx := ctxlogrus.ToContext(context.Background(), logrus.NewEntry(logrus.New())) + ctx := armadacontext.Background() result, err := sch.Schedule(ctx) if err != nil { return err @@ -775,7 +773,7 @@ func (s *Simulator) handleJobRunPreempted(txn *jobdb.Txn, e *armadaevents.JobRun return true, nil } -// func (a *App) TestPattern(ctx context.Context, pattern string) (*TestSuiteReport, error) { +// func (a *App) TestPattern(ctx *context.Context, pattern string) (*TestSuiteReport, error) { // testSpecs, err := TestSpecsFromPattern(pattern) // if err != nil { // return nil, err diff --git a/internal/scheduler/submitcheck.go b/internal/scheduler/submitcheck.go index 6221e2611e9..bf79e0eb317 100644 --- a/internal/scheduler/submitcheck.go +++ b/internal/scheduler/submitcheck.go @@ -1,7 +1,6 @@ package scheduler import ( - "context" "fmt" "strings" "sync" @@ -14,6 +13,7 @@ import ( "k8s.io/apimachinery/pkg/util/clock" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" armadaslices "github.com/armadaproject/armada/internal/common/slices" "github.com/armadaproject/armada/internal/common/types" schedulercontext "github.com/armadaproject/armada/internal/scheduler/context" @@ -84,7 +84,7 @@ func NewSubmitChecker( } } -func (srv *SubmitChecker) Run(ctx context.Context) error { +func (srv *SubmitChecker) Run(ctx *armadacontext.Context) error { srv.updateExecutors(ctx) ticker := time.NewTicker(srv.ExecutorUpdateFrequency) @@ -98,7 +98,7 @@ func (srv *SubmitChecker) Run(ctx context.Context) error { } } -func (srv *SubmitChecker) updateExecutors(ctx context.Context) { +func (srv *SubmitChecker) updateExecutors(ctx *armadacontext.Context) { executors, err := srv.executorRepository.GetExecutors(ctx) if err != nil { log.WithError(err).Error("Error fetching executors") diff --git a/internal/scheduler/submitcheck_test.go b/internal/scheduler/submitcheck_test.go index a95f3d9abbf..87be5674bf8 100644 --- a/internal/scheduler/submitcheck_test.go +++ b/internal/scheduler/submitcheck_test.go @@ -1,7 +1,6 @@ package scheduler import ( - "context" "testing" "time" @@ -14,6 +13,7 @@ import ( "k8s.io/apimachinery/pkg/util/clock" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/scheduler/jobdb" schedulermocks "github.com/armadaproject/armada/internal/scheduler/mocks" @@ -72,7 +72,7 @@ func TestSubmitChecker_CheckJobDbJobs(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() ctrl := gomock.NewController(t) @@ -170,7 +170,7 @@ func TestSubmitChecker_TestCheckApiJobs(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t 
*testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() ctrl := gomock.NewController(t) diff --git a/internal/scheduler/testfixtures/testfixtures.go b/internal/scheduler/testfixtures/testfixtures.go index 0acda7d60a6..e73d246c74a 100644 --- a/internal/scheduler/testfixtures/testfixtures.go +++ b/internal/scheduler/testfixtures/testfixtures.go @@ -2,16 +2,13 @@ package testfixtures // This file contains test fixtures to be used throughout the tests for this package. import ( - "context" "fmt" "math" "sync/atomic" "time" "github.com/google/uuid" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/oklog/ulid" - "github.com/sirupsen/logrus" "golang.org/x/exp/maps" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" @@ -82,10 +79,6 @@ func Repeat[T any](v T, n int) []T { return rv } -func ContextWithDefaultLogger(ctx context.Context) context.Context { - return ctxlogrus.ToContext(ctx, logrus.NewEntry(logrus.New())) -} - func TestSchedulingConfig() configuration.SchedulingConfig { return configuration.SchedulingConfig{ ResourceScarcity: map[string]float64{"cpu": 1}, diff --git a/internal/scheduleringester/instructions.go b/internal/scheduleringester/instructions.go index 429ab2d9112..4a8bc70fd51 100644 --- a/internal/scheduleringester/instructions.go +++ b/internal/scheduleringester/instructions.go @@ -1,7 +1,6 @@ package scheduleringester import ( - "context" "time" "github.com/gogo/protobuf/proto" @@ -10,6 +9,7 @@ import ( "golang.org/x/exp/maps" "golang.org/x/exp/slices" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/ingest" "github.com/armadaproject/armada/internal/common/ingest/metrics" @@ -46,7 +46,7 @@ func NewInstructionConverter( } } -func (c *InstructionConverter) Convert(_ context.Context, sequencesWithIds *ingest.EventSequencesWithIds) *DbOperationsWithMessageIds { +func (c *InstructionConverter) Convert(_ *armadacontext.Context, sequencesWithIds *ingest.EventSequencesWithIds) *DbOperationsWithMessageIds { operations := make([]DbOperation, 0) for _, es := range sequencesWithIds.EventSequences { for _, op := range c.dbOperationsFromEventSequence(es) { diff --git a/internal/scheduleringester/schedulerdb.go b/internal/scheduleringester/schedulerdb.go index e1ce855504b..058f0f4778b 100644 --- a/internal/scheduleringester/schedulerdb.go +++ b/internal/scheduleringester/schedulerdb.go @@ -1,7 +1,6 @@ package scheduleringester import ( - "context" "time" "github.com/google/uuid" @@ -10,6 +9,7 @@ import ( "github.com/pkg/errors" "golang.org/x/exp/maps" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database" "github.com/armadaproject/armada/internal/common/ingest" "github.com/armadaproject/armada/internal/common/ingest/metrics" @@ -45,14 +45,14 @@ func NewSchedulerDb( // Store persists all operations in the database. // This function retires until it either succeeds or encounters a terminal error. // This function locks the postgres table to avoid write conflicts; see acquireLock() for details. 
-func (s *SchedulerDb) Store(ctx context.Context, instructions *DbOperationsWithMessageIds) error { +func (s *SchedulerDb) Store(ctx *armadacontext.Context, instructions *DbOperationsWithMessageIds) error { return ingest.WithRetry(func() (bool, error) { err := pgx.BeginTxFunc(ctx, s.db, pgx.TxOptions{ IsoLevel: pgx.ReadCommitted, AccessMode: pgx.ReadWrite, DeferrableMode: pgx.Deferrable, }, func(tx pgx.Tx) error { - lockCtx, cancel := context.WithTimeout(ctx, s.lockTimeout) + lockCtx, cancel := armadacontext.WithTimeout(ctx, s.lockTimeout) defer cancel() // The lock is released automatically on transaction rollback/commit. if err := s.acquireLock(lockCtx, tx); err != nil { @@ -78,7 +78,7 @@ func (s *SchedulerDb) Store(ctx context.Context, instructions *DbOperationsWithM // rows with sequence numbers smaller than those already written. // // The scheduler relies on these sequence numbers to only fetch new or updated rows in each update cycle. -func (s *SchedulerDb) acquireLock(ctx context.Context, tx pgx.Tx) error { +func (s *SchedulerDb) acquireLock(ctx *armadacontext.Context, tx pgx.Tx) error { const lockId = 8741339439634283896 if _, err := tx.Exec(ctx, "SELECT pg_advisory_xact_lock($1)", lockId); err != nil { return errors.Wrapf(err, "could not obtain lock") @@ -86,7 +86,7 @@ func (s *SchedulerDb) acquireLock(ctx context.Context, tx pgx.Tx) error { return nil } -func (s *SchedulerDb) WriteDbOp(ctx context.Context, tx pgx.Tx, op DbOperation) error { +func (s *SchedulerDb) WriteDbOp(ctx *armadacontext.Context, tx pgx.Tx, op DbOperation) error { queries := schedulerdb.New(tx) switch o := op.(type) { case InsertJobs: @@ -274,7 +274,7 @@ func (s *SchedulerDb) WriteDbOp(ctx context.Context, tx pgx.Tx, op DbOperation) return nil } -func execBatch(ctx context.Context, tx pgx.Tx, batch *pgx.Batch) error { +func execBatch(ctx *armadacontext.Context, tx pgx.Tx, batch *pgx.Batch) error { result := tx.SendBatch(ctx, batch) for i := 0; i < batch.Len(); i++ { _, err := result.Exec() diff --git a/internal/scheduleringester/schedulerdb_test.go b/internal/scheduleringester/schedulerdb_test.go index 8317e421aff..873885c369e 100644 --- a/internal/scheduleringester/schedulerdb_test.go +++ b/internal/scheduleringester/schedulerdb_test.go @@ -1,7 +1,6 @@ package scheduleringester import ( - "context" "testing" "time" @@ -14,6 +13,7 @@ import ( "golang.org/x/exp/constraints" "golang.org/x/exp/maps" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/ingest/metrics" "github.com/armadaproject/armada/internal/common/util" schedulerdb "github.com/armadaproject/armada/internal/scheduler/database" @@ -312,7 +312,7 @@ func addDefaultValues(op DbOperation) DbOperation { } func assertOpSuccess(t *testing.T, schedulerDb *SchedulerDb, serials map[string]int64, op DbOperation) error { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 10*time.Second) defer cancel() // Apply the op to the database. @@ -329,7 +329,7 @@ func assertOpSuccess(t *testing.T, schedulerDb *SchedulerDb, serials map[string] // Read back the state from the db to compare. 
queries := schedulerdb.New(schedulerDb.db) - selectNewJobs := func(ctx context.Context, serial int64) ([]schedulerdb.Job, error) { + selectNewJobs := func(ctx *armadacontext.Context, serial int64) ([]schedulerdb.Job, error) { return queries.SelectNewJobs(ctx, schedulerdb.SelectNewJobsParams{Serial: serial, Limit: 1000}) } switch expected := op.(type) { @@ -645,7 +645,7 @@ func TestStore(t *testing.T) { runId: &JobRunDetails{queue: testQueueName, dbRun: &schedulerdb.Run{JobID: jobId, RunID: runId}}, }, } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() err := schedulerdb.WithTestDb(func(q *schedulerdb.Queries, db *pgxpool.Pool) error { schedulerDb := NewSchedulerDb(db, metrics.NewMetrics("test"), time.Second, time.Second, 10*time.Second)