Skip to content

Commit

Permalink
ArmadaContext that includes a logger (#2934)
Browse files Browse the repository at this point in the history
* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* compilation!

* rename package

* more compilation

* rename to Context

* embed

* compilation

* compilation

* fix test

* remove old ctxloggers

* revert design doc

* revert developer doc

* formatting

* wip

* tests

* don't gen

* don't gen

* merged master

---------

Co-authored-by: Chris Martin <chris@cmartinit.co.uk>
Co-authored-by: Albin Severinson <albin@severinson.org>
  • Loading branch information
3 people committed Sep 11, 2023
1 parent 7cdf653 commit d616feb
Show file tree
Hide file tree
Showing 178 changed files with 1,331 additions and 1,418 deletions.
5 changes: 2 additions & 3 deletions cmd/armada/main.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package main

import (
"context"
"fmt"
"net/http"
_ "net/http/pprof"
Expand All @@ -13,11 +12,11 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/pflag"
"github.com/spf13/viper"
"golang.org/x/sync/errgroup"

"github.com/armadaproject/armada/internal/armada"
"github.com/armadaproject/armada/internal/armada/configuration"
"github.com/armadaproject/armada/internal/common"
"github.com/armadaproject/armada/internal/common/armadacontext"
gateway "github.com/armadaproject/armada/internal/common/grpc"
"github.com/armadaproject/armada/internal/common/health"
"github.com/armadaproject/armada/internal/common/logging"
Expand Down Expand Up @@ -67,7 +66,7 @@ func main() {
}

// Run services within an errgroup to propagate errors between services.
g, ctx := errgroup.WithContext(context.Background())
g, ctx := armadacontext.ErrGroup(armadacontext.Background())

// Cancel the errgroup context on SIGINT and SIGTERM,
// which shuts everything down gracefully.
Expand Down
8 changes: 4 additions & 4 deletions cmd/eventsprinter/logic/logic.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package logic

import (
"context"
"fmt"
"time"

"github.com/apache/pulsar-client-go/pulsar"
"github.com/gogo/protobuf/proto"
v1 "k8s.io/api/core/v1"

"github.com/armadaproject/armada/internal/common/armadacontext"
"github.com/armadaproject/armada/internal/common/util"
"github.com/armadaproject/armada/pkg/armadaevents"
)
Expand All @@ -18,7 +18,7 @@ func PrintEvents(url, topic, subscription string, verbose bool) error {
fmt.Println("URL:", url)
fmt.Println("Topic:", topic)
fmt.Println("Subscription", subscription)
return withSetup(url, topic, subscription, func(ctx context.Context, producer pulsar.Producer, consumer pulsar.Consumer) error {
return withSetup(url, topic, subscription, func(ctx *armadacontext.Context, producer pulsar.Producer, consumer pulsar.Consumer) error {
// Number of active jobs.
numJobs := 0

Expand Down Expand Up @@ -199,7 +199,7 @@ func stripPodSpec(spec *v1.PodSpec) *v1.PodSpec {
}

// Run action with an Armada submit client and a Pulsar producer and consumer.
func withSetup(url, topic, subscription string, action func(ctx context.Context, producer pulsar.Producer, consumer pulsar.Consumer) error) error {
func withSetup(url, topic, subscription string, action func(ctx *armadacontext.Context, producer pulsar.Producer, consumer pulsar.Consumer) error) error {
pulsarClient, err := pulsar.NewClient(pulsar.ClientOptions{
URL: url,
})
Expand All @@ -225,5 +225,5 @@ func withSetup(url, topic, subscription string, action func(ctx context.Context,
}
defer consumer.Close()

return action(context.Background(), producer, consumer)
return action(armadacontext.Background(), producer, consumer)
}
4 changes: 2 additions & 2 deletions cmd/executor/main.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package main

import (
"context"
"net/http"
"os"
"os/signal"
Expand All @@ -13,6 +12,7 @@ import (
"github.com/spf13/viper"

"github.com/armadaproject/armada/internal/common"
"github.com/armadaproject/armada/internal/common/armadacontext"
"github.com/armadaproject/armada/internal/common/health"
"github.com/armadaproject/armada/internal/executor"
"github.com/armadaproject/armada/internal/executor/configuration"
Expand Down Expand Up @@ -55,7 +55,7 @@ func main() {
)
defer shutdownMetricServer()

shutdown, wg := executor.StartUp(context.Background(), logrus.NewEntry(logrus.New()), config)
shutdown, wg := executor.StartUp(armadacontext.Background(), logrus.NewEntry(logrus.New()), config)
go func() {
<-shutdownChannel
shutdown()
Expand Down
14 changes: 7 additions & 7 deletions cmd/lookoutv2/main.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package main

import (
"context"
"os"
"os/signal"
"syscall"
Expand All @@ -12,6 +11,7 @@ import (
"k8s.io/apimachinery/pkg/util/clock"

"github.com/armadaproject/armada/internal/common"
"github.com/armadaproject/armada/internal/common/armadacontext"
"github.com/armadaproject/armada/internal/common/database"
"github.com/armadaproject/armada/internal/lookoutv2"
"github.com/armadaproject/armada/internal/lookoutv2/configuration"
Expand All @@ -36,9 +36,9 @@ func init() {
pflag.Parse()
}

func makeContext() (context.Context, func()) {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
func makeContext() (*armadacontext.Context, func()) {
ctx := armadacontext.Background()
ctx, cancel := armadacontext.WithCancel(ctx)

c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
Expand All @@ -57,7 +57,7 @@ func makeContext() (context.Context, func()) {
}
}

func migrate(ctx context.Context, config configuration.LookoutV2Configuration) {
func migrate(ctx *armadacontext.Context, config configuration.LookoutV2Configuration) {
db, err := database.OpenPgxPool(config.Postgres)
if err != nil {
panic(err)
Expand All @@ -74,7 +74,7 @@ func migrate(ctx context.Context, config configuration.LookoutV2Configuration) {
}
}

func prune(ctx context.Context, config configuration.LookoutV2Configuration) {
func prune(ctx *armadacontext.Context, config configuration.LookoutV2Configuration) {
db, err := database.OpenPgxConn(config.Postgres)
if err != nil {
panic(err)
Expand All @@ -92,7 +92,7 @@ func prune(ctx context.Context, config configuration.LookoutV2Configuration) {
log.Infof("expireAfter: %v, batchSize: %v, timeout: %v",
config.PrunerConfig.ExpireAfter, config.PrunerConfig.BatchSize, config.PrunerConfig.Timeout)

ctxTimeout, cancel := context.WithTimeout(ctx, config.PrunerConfig.Timeout)
ctxTimeout, cancel := armadacontext.WithTimeout(ctx, config.PrunerConfig.Timeout)
defer cancel()
err = pruner.PruneDb(ctxTimeout, db, config.PrunerConfig.ExpireAfter, config.PrunerConfig.BatchSize, clock.RealClock{})
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions cmd/scheduler/cmd/migrate_database.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package cmd

import (
"context"
"time"

"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"

"github.com/armadaproject/armada/internal/common/armadacontext"
"github.com/armadaproject/armada/internal/common/database"
schedulerdb "github.com/armadaproject/armada/internal/scheduler/database"
)
Expand Down Expand Up @@ -43,7 +43,7 @@ func migrateDatabase(cmd *cobra.Command, _ []string) error {
return errors.WithMessagef(err, "Failed to connect to database")
}

ctx, cancel := context.WithTimeout(context.Background(), timeout)
ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), timeout)
defer cancel()
return schedulerdb.Migrate(ctx, db)
}
9 changes: 5 additions & 4 deletions internal/armada/repository/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

"github.com/armadaproject/armada/internal/armada/repository/apimessages"
"github.com/armadaproject/armada/internal/armada/repository/sequence"
"github.com/armadaproject/armada/internal/common/armadacontext"
"github.com/armadaproject/armada/internal/common/compress"
"github.com/armadaproject/armada/pkg/api"
"github.com/armadaproject/armada/pkg/armadaevents"
Expand Down Expand Up @@ -48,7 +49,7 @@ func NewEventRepository(db redis.UniversalClient) *RedisEventRepository {
NumTestsPerEvictionRun: 10,
}

decompressorPool := pool.NewObjectPool(context.Background(), pool.NewPooledObjectFactorySimple(
decompressorPool := pool.NewObjectPool(armadacontext.Background(), pool.NewPooledObjectFactorySimple(
func(context.Context) (interface{}, error) {
return compress.NewZlibDecompressor(), nil
}), &poolConfig)
Expand Down Expand Up @@ -134,16 +135,16 @@ func (repo *RedisEventRepository) GetLastMessageId(queue, jobSetId string) (stri
func (repo *RedisEventRepository) extractEvents(msg redis.XMessage, queue, jobSetId string) ([]*api.EventMessage, error) {
data := msg.Values[dataKey]
bytes := []byte(data.(string))
decompressor, err := repo.decompressorPool.BorrowObject(context.Background())
decompressor, err := repo.decompressorPool.BorrowObject(armadacontext.Background())
if err != nil {
return nil, errors.WithStack(err)
}
defer func(decompressorPool *pool.ObjectPool, ctx context.Context, object interface{}) {
defer func(decompressorPool *pool.ObjectPool, ctx *armadacontext.Context, object interface{}) {
err := decompressorPool.ReturnObject(ctx, object)
if err != nil {
log.WithError(err).Errorf("Error returning decompressor to pool")
}
}(repo.decompressorPool, context.Background(), decompressor)
}(repo.decompressorPool, armadacontext.Background(), decompressor)
decompressedData, err := decompressor.(compress.Decompressor).Decompress(bytes)
if err != nil {
return nil, errors.WithStack(err)
Expand Down
9 changes: 4 additions & 5 deletions internal/armada/repository/event_store.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
package repository

import (
"context"

"github.com/apache/pulsar-client-go/pulsar"

"github.com/armadaproject/armada/internal/common/armadacontext"
"github.com/armadaproject/armada/internal/common/eventutil"
"github.com/armadaproject/armada/internal/common/pulsarutils"
"github.com/armadaproject/armada/internal/common/schedulers"
"github.com/armadaproject/armada/pkg/api"
)

type EventStore interface {
ReportEvents(context.Context, []*api.EventMessage) error
ReportEvents(*armadacontext.Context, []*api.EventMessage) error
}

type TestEventStore struct {
ReceivedEvents []*api.EventMessage
}

func (es *TestEventStore) ReportEvents(_ context.Context, message []*api.EventMessage) error {
func (es *TestEventStore) ReportEvents(_ *armadacontext.Context, message []*api.EventMessage) error {
es.ReceivedEvents = append(es.ReceivedEvents, message...)
return nil
}
Expand All @@ -35,7 +34,7 @@ func NewEventStore(producer pulsar.Producer, maxAllowedMessageSize uint) *Stream
}
}

func (n *StreamEventStore) ReportEvents(ctx context.Context, apiEvents []*api.EventMessage) error {
func (n *StreamEventStore) ReportEvents(ctx *armadacontext.Context, apiEvents []*api.EventMessage) error {
if len(apiEvents) == 0 {
return nil
}
Expand Down
4 changes: 2 additions & 2 deletions internal/armada/scheduling/lease_manager.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package scheduling

import (
"context"
"time"

log "github.com/sirupsen/logrus"

"github.com/armadaproject/armada/internal/armada/repository"
"github.com/armadaproject/armada/internal/common/armadacontext"
"github.com/armadaproject/armada/pkg/api"
)

Expand Down Expand Up @@ -55,7 +55,7 @@ func (l *LeaseManager) ExpireLeases() {
if e != nil {
log.Error(e)
} else {
e := l.eventStore.ReportEvents(context.Background(), []*api.EventMessage{event})
e := l.eventStore.ReportEvents(armadacontext.Background(), []*api.EventMessage{event})
if e != nil {
log.Error(e)
}
Expand Down
9 changes: 4 additions & 5 deletions internal/armada/server.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package armada

import (
"context"
"fmt"
"net"
"time"
Expand All @@ -13,7 +12,6 @@ import (
"github.com/jackc/pgx/v5/pgxpool"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"

"github.com/armadaproject/armada/internal/armada/cache"
Expand All @@ -22,6 +20,7 @@ import (
"github.com/armadaproject/armada/internal/armada/repository"
"github.com/armadaproject/armada/internal/armada/scheduling"
"github.com/armadaproject/armada/internal/armada/server"
"github.com/armadaproject/armada/internal/common/armadacontext"
"github.com/armadaproject/armada/internal/common/auth"
"github.com/armadaproject/armada/internal/common/auth/authorization"
"github.com/armadaproject/armada/internal/common/database"
Expand All @@ -39,7 +38,7 @@ import (
"github.com/armadaproject/armada/pkg/client"
)

func Serve(ctx context.Context, config *configuration.ArmadaConfig, healthChecks *health.MultiChecker) error {
func Serve(ctx *armadacontext.Context, config *configuration.ArmadaConfig, healthChecks *health.MultiChecker) error {
log.Info("Armada server starting")
log.Infof("Armada priority classes: %v", config.Scheduling.Preemption.PriorityClasses)
log.Infof("Default priority class: %s", config.Scheduling.Preemption.DefaultPriorityClass)
Expand All @@ -51,9 +50,9 @@ func Serve(ctx context.Context, config *configuration.ArmadaConfig, healthChecks

// Run all services within an errgroup to propagate errors between services.
// Defer cancelling the parent context to ensure the errgroup is cancelled on return.
ctx, cancel := context.WithCancel(ctx)
ctx, cancel := armadacontext.WithCancel(ctx)
defer cancel()
g, ctx := errgroup.WithContext(ctx)
g, ctx := armadacontext.ErrGroup(ctx)

// List of services to run concurrently.
// Because we want to start services only once all input validation has been completed,
Expand Down
6 changes: 3 additions & 3 deletions internal/armada/server/authorization.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package server

import (
"context"
"fmt"
"strings"

"github.com/armadaproject/armada/internal/common/armadacontext"
"github.com/armadaproject/armada/internal/common/auth/authorization"
"github.com/armadaproject/armada/internal/common/auth/permission"
"github.com/armadaproject/armada/pkg/client/queue"
Expand Down Expand Up @@ -60,7 +60,7 @@ func MergePermissionErrors(errs ...*ErrUnauthorized) *ErrUnauthorized {
// permissions required to perform some action. The error returned is of type ErrUnauthorized.
// After recovering the error (using errors.As), the caller can obtain the name of the user and the
// requested permission programatically via this error type.
func checkPermission(p authorization.PermissionChecker, ctx context.Context, permission permission.Permission) error {
func checkPermission(p authorization.PermissionChecker, ctx *armadacontext.Context, permission permission.Permission) error {
if !p.UserHasPermission(ctx, permission) {
return &ErrUnauthorized{
Principal: authorization.GetPrincipal(ctx),
Expand All @@ -74,7 +74,7 @@ func checkPermission(p authorization.PermissionChecker, ctx context.Context, per

func checkQueuePermission(
p authorization.PermissionChecker,
ctx context.Context,
ctx *armadacontext.Context,
q queue.Queue,
globalPermission permission.Permission,
verb queue.PermissionVerb,
Expand Down
Loading

0 comments on commit d616feb

Please sign in to comment.