diff --git a/core/services/llo/data_source.go b/core/services/llo/data_source.go index 2afe9e090a3..2a99cb3bd1b 100644 --- a/core/services/llo/data_source.go +++ b/core/services/llo/data_source.go @@ -5,12 +5,12 @@ import ( "fmt" "slices" "sort" + "strconv" "sync" "time" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" - "github.com/shopspring/decimal" "golang.org/x/exp/maps" "github.com/smartcontractkit/chainlink-common/pkg/logger" @@ -19,7 +19,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/services/streams" - "github.com/smartcontractkit/chainlink/v2/core/utils" ) var ( @@ -42,7 +41,7 @@ var ( ) type Registry interface { - Get(streamID streams.StreamID) (strm streams.Stream, exists bool) + Get(streamID streams.StreamID) (p streams.Pipeline, exists bool) } type ErrObservationFailed struct { @@ -109,42 +108,45 @@ func (d *dataSource) Observe(ctx context.Context, streamValues llo.StreamValues, successfulStreamIDs := make([]streams.StreamID, 0, len(streamValues)) var errs []ErrObservationFailed + // oc only lives for the duration of this Observe call + oc := NewObservationContext(d.registry, d.t) + for _, streamID := range maps.Keys(streamValues) { go func(streamID llotypes.StreamID) { defer wg.Done() - - var val llo.StreamValue - - stream, exists := d.registry.Get(streamID) - if !exists { - mu.Lock() - errs = append(errs, ErrObservationFailed{streamID: streamID, reason: fmt.Sprintf("missing stream: %d", streamID)}) - mu.Unlock() - promMissingStreamCount.WithLabelValues(fmt.Sprintf("%d", streamID)).Inc() - return - } - run, trrs, err := stream.Run(ctx) - if err != nil { - mu.Lock() - errs = append(errs, ErrObservationFailed{inner: err, run: run, streamID: streamID, reason: "pipeline run failed"}) - mu.Unlock() - promObservationErrorCount.WithLabelValues(fmt.Sprintf("%d", streamID)).Inc() - // TODO: Consolidate/reduce telemetry. We should send all observation results in a single packet - // https://smartcontract-it.atlassian.net/browse/MERC-6290 - d.t.EnqueueV3PremiumLegacy(run, trrs, streamID, opts, nil, err) - return - } - // TODO: Consolidate/reduce telemetry. We should send all observation results in a single packet - // https://smartcontract-it.atlassian.net/browse/MERC-6290 - val, err = ExtractStreamValue(trrs) + val, err := oc.Observe(ctx, streamID, opts) if err != nil { + promObservationErrorCount.WithLabelValues(strconv.FormatUint(uint64(streamID), 10)).Inc() mu.Lock() - errs = append(errs, ErrObservationFailed{inner: err, run: run, streamID: streamID, reason: "failed to extract big.Int"}) + errs = append(errs, ErrObservationFailed{inner: err, streamID: streamID, reason: "failed to observe stream"}) mu.Unlock() return } - d.t.EnqueueV3PremiumLegacy(run, trrs, streamID, opts, val, nil) + // TODO: check telem/prom + // var val llo.StreamValue + + // stream, exists := d.registry.Get(streamID) + // if !exists { + // mu.Lock() + // errs = append(errs, ErrObservationFailed{streamID: streamID, reason: fmt.Sprintf("missing stream: %d", streamID)}) + // mu.Unlock() + // promMissingStreamCount.WithLabelValues(fmt.Sprintf("%d", streamID)).Inc() + // return + // } + // run, trrs, err := stream.Run(ctx) + // if err != nil { + // mu.Lock() + // errs = append(errs, ErrObservationFailed{inner: err, run: run, streamID: streamID, reason: "pipeline run failed"}) + // mu.Unlock() + // promObservationErrorCount.WithLabelValues(fmt.Sprintf("%d", streamID)).Inc() + // // TODO: Consolidate/reduce telemetry. We should send all observation results in a single packet + // // https://smartcontract-it.atlassian.net/browse/MERC-6290 + // d.t.EnqueueV3PremiumLegacy(run, trrs, streamID, opts, nil, err) + // return + // } + // // TODO: Consolidate/reduce telemetry. We should send all observation results in a single packet + // // https://smartcontract-it.atlassian.net/browse/MERC-6290 mu.Lock() defer mu.Unlock() @@ -186,54 +188,3 @@ func (d *dataSource) Observe(ctx context.Context, streamValues llo.StreamValues, return nil } - -// ExtractStreamValue extracts a StreamValue from a TaskRunResults -func ExtractStreamValue(trrs pipeline.TaskRunResults) (llo.StreamValue, error) { - // pipeline.TaskRunResults comes ordered asc by index, this is guaranteed - // by the pipeline executor - finaltrrs := trrs.Terminals() - - // HACK: Right now we rely on the number of outputs to determine whether - // its a Decimal or a Quote. - // This isn't very robust or future-proof but is sufficient to support v0.3 - // compat. - // There are a number of different possible ways to solve this in future. - // See: https://smartcontract-it.atlassian.net/browse/MERC-5934 - switch len(finaltrrs) { - case 1: - res := finaltrrs[0].Result - if res.Error != nil { - return nil, res.Error - } - val, err := toDecimal(res.Value) - if err != nil { - return nil, fmt.Errorf("failed to parse BenchmarkPrice: %w", err) - } - return llo.ToDecimal(val), nil - case 3: - // Expect ordering of Benchmark, Bid, Ask - results := make([]decimal.Decimal, 3) - for i, trr := range finaltrrs { - res := trr.Result - if res.Error != nil { - return nil, fmt.Errorf("failed to parse stream output into Quote (task index: %d): %w", i, res.Error) - } - val, err := toDecimal(res.Value) - if err != nil { - return nil, fmt.Errorf("failed to parse decimal: %w", err) - } - results[i] = val - } - return &llo.Quote{ - Benchmark: results[0], - Bid: results[1], - Ask: results[2], - }, nil - default: - return nil, fmt.Errorf("invalid number of results, expected: 1 or 3, got: %d", len(finaltrrs)) - } -} - -func toDecimal(val interface{}) (decimal.Decimal, error) { - return utils.ToDecimal(val) -} diff --git a/core/services/llo/data_source_test.go b/core/services/llo/data_source_test.go index 932c4c0c73a..87370d4442a 100644 --- a/core/services/llo/data_source_test.go +++ b/core/services/llo/data_source_test.go @@ -21,27 +21,36 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/streams" ) -type mockStream struct { +type mockPipeline struct { run *pipeline.Run trrs pipeline.TaskRunResults err error + + streamIDs []streams.StreamID + + runCount int } -func (m *mockStream) Run(ctx context.Context) (*pipeline.Run, pipeline.TaskRunResults, error) { +func (m *mockPipeline) Run(ctx context.Context) (*pipeline.Run, pipeline.TaskRunResults, error) { + m.runCount++ return m.run, m.trrs, m.err } +func (m *mockPipeline) StreamIDs() []streams.StreamID { + return m.streamIDs +} + type mockRegistry struct { - streams map[streams.StreamID]*mockStream + pipelines map[streams.StreamID]*mockPipeline } -func (m *mockRegistry) Get(streamID streams.StreamID) (strm streams.Stream, exists bool) { - strm, exists = m.streams[streamID] +func (m *mockRegistry) Get(streamID streams.StreamID) (p streams.Pipeline, exists bool) { + p, exists = m.pipelines[streamID] return } -func makeStreamWithSingleResult[T any](runID int64, res T, err error) *mockStream { - return &mockStream{ +func makePipelineWithSingleResult[T any](runID int64, res T, err error) *mockPipeline { + return &mockPipeline{ run: &pipeline.Run{ID: runID}, trrs: []pipeline.TaskRunResult{pipeline.TaskRunResult{Task: &pipeline.MemoTask{}, Result: pipeline.Result{Value: res}}}, err: err, @@ -91,7 +100,7 @@ func (m *mockTelemeter) EnqueueV3PremiumLegacy(run *pipeline.Run, trrs pipeline. func Test_DataSource(t *testing.T) { lggr := logger.TestLogger(t) - reg := &mockRegistry{make(map[streams.StreamID]*mockStream)} + reg := &mockRegistry{make(map[streams.StreamID]*mockPipeline)} ds := newDataSource(lggr, reg, NullTelemeter) ctx := testutils.Context(t) opts := &mockOpts{} @@ -105,9 +114,9 @@ func Test_DataSource(t *testing.T) { assert.Equal(t, makeStreamValues(), vals) }) t.Run("observes each stream with success and returns values matching map argument", func(t *testing.T) { - reg.streams[1] = makeStreamWithSingleResult[*big.Int](1, big.NewInt(2181), nil) - reg.streams[2] = makeStreamWithSingleResult[*big.Int](2, big.NewInt(40602), nil) - reg.streams[3] = makeStreamWithSingleResult[*big.Int](3, big.NewInt(15), nil) + reg.pipelines[1] = makePipelineWithSingleResult[*big.Int](1, big.NewInt(2181), nil) + reg.pipelines[2] = makePipelineWithSingleResult[*big.Int](2, big.NewInt(40602), nil) + reg.pipelines[3] = makePipelineWithSingleResult[*big.Int](3, big.NewInt(15), nil) vals := makeStreamValues() err := ds.Observe(ctx, vals, opts) @@ -120,9 +129,9 @@ func Test_DataSource(t *testing.T) { }, vals) }) t.Run("observes each stream and returns success/errors", func(t *testing.T) { - reg.streams[1] = makeStreamWithSingleResult[*big.Int](1, big.NewInt(2181), errors.New("something exploded")) - reg.streams[2] = makeStreamWithSingleResult[*big.Int](2, big.NewInt(40602), nil) - reg.streams[3] = makeStreamWithSingleResult[*big.Int](3, nil, errors.New("something exploded 2")) + reg.pipelines[1] = makePipelineWithSingleResult[*big.Int](1, big.NewInt(2181), errors.New("something exploded")) + reg.pipelines[2] = makePipelineWithSingleResult[*big.Int](2, big.NewInt(40602), nil) + reg.pipelines[3] = makePipelineWithSingleResult[*big.Int](3, nil, errors.New("something exploded 2")) vals := makeStreamValues() err := ds.Observe(ctx, vals, opts) @@ -139,9 +148,9 @@ func Test_DataSource(t *testing.T) { tm := &mockTelemeter{} ds.t = tm - reg.streams[1] = makeStreamWithSingleResult[*big.Int](100, big.NewInt(2181), nil) - reg.streams[2] = makeStreamWithSingleResult[*big.Int](101, big.NewInt(40602), nil) - reg.streams[3] = makeStreamWithSingleResult[*big.Int](102, big.NewInt(15), nil) + reg.pipelines[1] = makePipelineWithSingleResult[*big.Int](100, big.NewInt(2181), nil) + reg.pipelines[2] = makePipelineWithSingleResult[*big.Int](101, big.NewInt(40602), nil) + reg.pipelines[3] = makePipelineWithSingleResult[*big.Int](102, big.NewInt(15), nil) vals := makeStreamValues() err := ds.Observe(ctx, vals, opts) @@ -166,5 +175,37 @@ func Test_DataSource(t *testing.T) { assert.Equal(t, "2181", pkt.val.(*llo.Decimal).String()) assert.Nil(t, pkt.err) }) + + t.Run("records telemetry for errors", func(t *testing.T) { + tm := &mockTelemeter{} + ds.t = tm + + reg.pipelines[1] = makePipelineWithSingleResult[*big.Int](100, big.NewInt(2181), errors.New("something exploded")) + reg.pipelines[2] = makePipelineWithSingleResult[*big.Int](101, big.NewInt(40602), nil) + reg.pipelines[3] = makePipelineWithSingleResult[*big.Int](102, nil, errors.New("something exploded 2")) + + vals := makeStreamValues() + err := ds.Observe(ctx, vals, opts) + assert.NoError(t, err) + + assert.Equal(t, llo.StreamValues{ + 2: llo.ToDecimal(decimal.NewFromInt(40602)), + 1: nil, + 3: nil, + }, vals) + + require.Len(t, tm.v3PremiumLegacyPackets, 3) + m := make(map[int]v3PremiumLegacyPacket) + for _, pkt := range tm.v3PremiumLegacyPackets { + m[int(pkt.run.ID)] = pkt + } + pkt := m[100] + assert.Equal(t, 100, int(pkt.run.ID)) + assert.Len(t, pkt.trrs, 1) + assert.Equal(t, 1, int(pkt.streamID)) + assert.Equal(t, opts, pkt.opts) + assert.Nil(t, pkt.val) + assert.NotNil(t, pkt.err) + }) }) } diff --git a/core/services/llo/observation_context.go b/core/services/llo/observation_context.go new file mode 100644 index 00000000000..45554bfb565 --- /dev/null +++ b/core/services/llo/observation_context.go @@ -0,0 +1,181 @@ +package llo + +import ( + "context" + "fmt" + "sync" + + "github.com/shopspring/decimal" + + "github.com/smartcontractkit/chainlink-data-streams/llo" + "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" + "github.com/smartcontractkit/chainlink/v2/core/services/streams" + "github.com/smartcontractkit/chainlink/v2/core/utils" +) + +// ObservationContext ensures that each pipeline is only executed once. It is +// intended to be instantiated and used then discarded as part of one +// Observation cycle. Subsequent calls to Observe will return the same cached +// values. + +var _ ObservationContext = (*observationContext)(nil) + +type ObservationContext interface { + Observe(ctx context.Context, streamID streams.StreamID, opts llo.DSOpts) (val llo.StreamValue, err error) +} + +type execution struct { + done <-chan struct{} + + run *pipeline.Run + trrs pipeline.TaskRunResults + err error +} + +type observationContext struct { + r Registry + t Telemeter + + executionsMu sync.Mutex + // only execute each pipeline once + executions map[streams.Pipeline]*execution +} + +func NewObservationContext(r Registry, t Telemeter) ObservationContext { + return newObservationContext(r, t) +} + +func newObservationContext(r Registry, t Telemeter) *observationContext { + return &observationContext{r, t, sync.Mutex{}, make(map[streams.Pipeline]*execution)} +} + +func (oc *observationContext) Observe(ctx context.Context, streamID streams.StreamID, opts llo.DSOpts) (val llo.StreamValue, err error) { + run, trrs, err := oc.run(ctx, streamID) + if err != nil { + // FIXME: This is a hack specific for V3 telemetry, future schemas should + // use a generic stream value telemetry instead + // https://smartcontract-it.atlassian.net/browse/MERC-6290 + oc.t.EnqueueV3PremiumLegacy(run, trrs, streamID, opts, val, err) + return nil, err + } + // Extract stream value based on streamID attribute + for _, trr := range trrs { + if trr.Task.TaskStreamID() != nil && *trr.Task.TaskStreamID() == streamID { + return resultToStreamValue(trr.Result.Value) + } + } + // If no streamID attribute is found in the task results, then assume the + // final output is the stream ID and return that. This is safe to do since + // the registry will never return a spec that doesn't match either by tag + // or by spec streamID. + + val, err = extractFinalResultAsStreamValue(trrs) + // FIXME: This is a hack specific for V3 telemetry, future schemas should + // use a generic stream value telemetry instead + // https://smartcontract-it.atlassian.net/browse/MERC-6290 + oc.t.EnqueueV3PremiumLegacy(run, trrs, streamID, opts, val, err) + return +} + +func resultToStreamValue(val interface{}) (llo.StreamValue, error) { + switch v := val.(type) { + case decimal.Decimal: + return llo.ToDecimal(v), nil + default: + return nil, fmt.Errorf("don't know how to convert pipeline output result of type %T to llo.StreamValue", val) + } +} + +// extractFinalResultAsStreamValue extracts a final StreamValue from a TaskRunResults +func extractFinalResultAsStreamValue(trrs pipeline.TaskRunResults) (llo.StreamValue, error) { + // pipeline.TaskRunResults comes ordered asc by index, this is guaranteed + // by the pipeline executor + finaltrrs := trrs.Terminals() + + // HACK: Right now we rely on the number of outputs to determine whether + // its a Decimal or a Quote. + // This is a hack to support the legacy "Quote" case. + // Future stream specs should use streamID tags instead. + switch len(finaltrrs) { + case 1: + res := finaltrrs[0].Result + if res.Error != nil { + return nil, res.Error + } + val, err := toDecimal(res.Value) + if err != nil { + return nil, fmt.Errorf("failed to parse BenchmarkPrice: %w", err) + } + return llo.ToDecimal(val), nil + case 3: + // Expect ordering of Benchmark, Bid, Ask + results := make([]decimal.Decimal, 3) + for i, trr := range finaltrrs { + res := trr.Result + if res.Error != nil { + return nil, fmt.Errorf("failed to parse stream output into Quote (task index: %d): %w", i, res.Error) + } + val, err := toDecimal(res.Value) + if err != nil { + return nil, fmt.Errorf("failed to parse decimal: %w", err) + } + results[i] = val + } + return &llo.Quote{ + Benchmark: results[0], + Bid: results[1], + Ask: results[2], + }, nil + default: + return nil, fmt.Errorf("invalid number of results, expected: 1 or 3, got: %d", len(finaltrrs)) + } +} + +func toDecimal(val interface{}) (decimal.Decimal, error) { + return utils.ToDecimal(val) +} + +type ErrMissingStream struct { + StreamID streams.StreamID +} + +func (e ErrMissingStream) Error() string { + return fmt.Sprintf("no pipeline for stream: %d", e.StreamID) +} + +func (oc *observationContext) run(ctx context.Context, streamID streams.StreamID) (*pipeline.Run, pipeline.TaskRunResults, error) { + strm, exists := oc.r.Get(streamID) + if !exists { + // promMissingStreamCount.WithLabelValues(fmt.Sprintf("%d", streamID)).Inc() + return nil, nil, ErrMissingStream{StreamID: streamID} + } + + // In case of multiple streamIDs per stream (FIXME: naming??) then the + // first call executes and the others wait for result + oc.executionsMu.Lock() + ex, isExecuting := oc.executions[strm] + if isExecuting { + oc.executionsMu.Unlock() + // wait for it to finish + select { + case <-ex.done: + return ex.run, ex.trrs, ex.err + case <-ctx.Done(): + return nil, nil, ctx.Err() + } + } + + // execute here + ch := make(chan struct{}) + ex = &execution{done: ch} + oc.executions[strm] = ex + oc.executionsMu.Unlock() + + run, trrs, err := strm.Run(ctx) + ex.run = run + ex.trrs = trrs + ex.err = err + close(ch) + + return run, trrs, err +} diff --git a/core/services/llo/observation_context_test.go b/core/services/llo/observation_context_test.go new file mode 100644 index 00000000000..958da0351cc --- /dev/null +++ b/core/services/llo/observation_context_test.go @@ -0,0 +1,107 @@ +package llo + +import ( + "errors" + "math/rand/v2" + "testing" + + "github.com/shopspring/decimal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" + "github.com/smartcontractkit/chainlink-data-streams/llo" + "github.com/smartcontractkit/chainlink/v2/core/null" + "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" + "github.com/smartcontractkit/chainlink/v2/core/services/streams" +) + +func makeErroringPipeline() *mockPipeline { + return &mockPipeline{ + err: errors.New("pipeline error"), + } +} + +func makePipelineWithMultipleStreamResults(streamIDs []streams.StreamID, results []interface{}) *mockPipeline { + if len(streamIDs) != len(results) { + panic("streamIDs and results must have the same length") + } + trrs := make([]pipeline.TaskRunResult, len(streamIDs)) + for i, res := range results { + trrs[i] = pipeline.TaskRunResult{Task: &pipeline.MemoTask{BaseTask: pipeline.BaseTask{StreamID: null.Uint32From(streamIDs[i])}}, Result: pipeline.Result{Value: res}} + } + return &mockPipeline{ + run: &pipeline.Run{}, + trrs: trrs, + err: nil, + streamIDs: streamIDs, + } +} + +func TestObservationContext_Observe(t *testing.T) { + ctx := tests.Context(t) + r := &mockRegistry{} + telem := &mockTelemeter{} + oc := newObservationContext(r, telem) + opts := llo.DSOpts(nil) + + missingStreamID := streams.StreamID(0) + streamID1 := streams.StreamID(1) + streamID2 := streams.StreamID(2) + streamID3 := streams.StreamID(3) + streamID4 := streams.StreamID(4) + streamID5 := streams.StreamID(5) + streamID6 := streams.StreamID(6) + + multiPipelineDecimal := makePipelineWithMultipleStreamResults([]streams.StreamID{streamID4, streamID5, streamID6}, []interface{}{decimal.NewFromFloat(12.34), decimal.NewFromFloat(56.78), decimal.NewFromFloat(90.12)}) + + r.pipelines = map[streams.StreamID]*mockPipeline{ + streamID1: &mockPipeline{}, + streamID2: makePipelineWithSingleResult[decimal.Decimal](rand.Int64(), decimal.NewFromFloat(12.34), nil), + streamID3: makeErroringPipeline(), + streamID4: multiPipelineDecimal, + streamID5: multiPipelineDecimal, + streamID6: multiPipelineDecimal, + } + + t.Run("returns error in case of missing pipeline", func(t *testing.T) { + _, err := oc.Observe(ctx, missingStreamID, opts) + require.EqualError(t, err, "no pipeline for stream: 0") + }) + t.Run("returns error in case of zero results", func(t *testing.T) { + _, err := oc.Observe(ctx, streamID1, opts) + require.EqualError(t, err, "invalid number of results, expected: 1 or 3, got: 0") + }) + t.Run("returns composite value from legacy job with single top-level streamID", func(t *testing.T) { + val, err := oc.Observe(ctx, streamID2, opts) + require.NoError(t, err) + + assert.Equal(t, "12.34", val.(*llo.Decimal).String()) + }) + t.Run("returns error in case of erroring pipeline", func(t *testing.T) { + _, err := oc.Observe(ctx, streamID3, opts) + require.EqualError(t, err, "pipeline error") + }) + t.Run("returns values for multiple stream IDs within the same job based on streamID tag with a single pipeline execution", func(t *testing.T) { + val, err := oc.Observe(ctx, streamID4, opts) + require.NoError(t, err) + assert.Equal(t, "12.34", val.(*llo.Decimal).String()) + + val, err = oc.Observe(ctx, streamID5, opts) + require.NoError(t, err) + assert.Equal(t, "56.78", val.(*llo.Decimal).String()) + + val, err = oc.Observe(ctx, streamID6, opts) + require.NoError(t, err) + assert.Equal(t, "90.12", val.(*llo.Decimal).String()) + + assert.Equal(t, 1, multiPipelineDecimal.runCount) + + // returns cached values on subsequent calls + val, err = oc.Observe(ctx, streamID6, opts) + require.NoError(t, err) + assert.Equal(t, "90.12", val.(*llo.Decimal).String()) + + assert.Equal(t, 1, multiPipelineDecimal.runCount) + }) +} diff --git a/core/services/pipeline/common.go b/core/services/pipeline/common.go index 50611ee32a4..56af199078a 100644 --- a/core/services/pipeline/common.go +++ b/core/services/pipeline/common.go @@ -61,6 +61,7 @@ type ( TaskMinBackoff() time.Duration TaskMaxBackoff() time.Duration TaskTags() string + TaskStreamID() *uint32 GetDescendantTasks() []Task } diff --git a/core/services/pipeline/task.base.go b/core/services/pipeline/task.base.go index 3e1db5fcdb5..fdedb69193e 100644 --- a/core/services/pipeline/task.base.go +++ b/core/services/pipeline/task.base.go @@ -24,6 +24,8 @@ type BaseTask struct { Tags string `mapstructure:"tags" json:"-"` + StreamID null.Uint32 `mapstructure:"streamID"` + uuid uuid.UUID } @@ -84,6 +86,13 @@ func (t BaseTask) TaskTags() string { return t.Tags } +func (t BaseTask) TaskStreamID() *uint32 { + if t.StreamID.Valid { + return &t.StreamID.Uint32 + } + return nil +} + // GetDescendantTasks retrieves all descendant tasks of a given task func (t BaseTask) GetDescendantTasks() []Task { if len(t.outputs) == 0 { diff --git a/core/services/streams/delegate.go b/core/services/streams/delegate.go index bf492d4bd15..2f62a7bf1f4 100644 --- a/core/services/streams/delegate.go +++ b/core/services/streams/delegate.go @@ -52,8 +52,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services [] rrs := ocrcommon.NewResultRunSaver(d.runner, lggr, d.cfg.MaxSuccessfulRuns(), d.cfg.ResultWriteQueueDepth()) services = append(services, rrs, &StreamService{ d.registry, - id, - jb.PipelineSpec, + jb, lggr, rrs, }) @@ -66,23 +65,22 @@ type ResultRunSaver interface { type StreamService struct { registry Registry - id StreamID - spec *pipeline.Spec + jb job.Job lggr logger.Logger rrs ResultRunSaver } func (s *StreamService) Start(_ context.Context) error { - if s.spec == nil { - return fmt.Errorf("pipeline spec unexpectedly missing for stream %q", s.id) + if s.jb.PipelineSpec == nil { + return errors.New("pipeline spec unexpectedly missing for stream") } - s.lggr.Debugf("Starting stream %d", s.id) - return s.registry.Register(s.id, *s.spec, s.rrs) + s.lggr.Debugw("Registering stream", "jobID", s.jb.ID) + return s.registry.Register(s.jb, s.rrs) } func (s *StreamService) Close() error { - s.lggr.Debugf("Stopping stream %d", s.id) - s.registry.Unregister(s.id) + s.lggr.Debugw("Unregistering stream", "jobID", s.jb.ID) + s.registry.Unregister(s.jb.ID) return nil } @@ -101,8 +99,23 @@ func ValidatedStreamSpec(tomlString string) (job.Job, error) { return jb, errors.Errorf("unsupported type: %q", jb.Type) } - if jb.StreamID == nil { - return jb, errors.New("jobs of type 'stream' require streamID to be specified") + // The spec stream ID is optional, but if provided represents the final output of the pipeline run. + // nodes in the DAG may also contain streamID tags. + // Every spec must have at least one streamID. + var streamIDs []StreamID + + if jb.StreamID != nil { + streamIDs = append(streamIDs, *jb.StreamID) + } + + for _, t := range jb.Pipeline.Tasks { + if streamID := t.TaskStreamID(); streamID != nil { + streamIDs = append(streamIDs, *streamID) + } + } + + if len(streamIDs) == 0 { + return jb, errors.New("no streamID found in spec (must be either specified as top-level key 'streamID' or at least one streamID tag must be provided in the pipeline)") } return jb, nil diff --git a/core/services/streams/pipeline.go b/core/services/streams/pipeline.go new file mode 100644 index 00000000000..273f61e6add --- /dev/null +++ b/core/services/streams/pipeline.go @@ -0,0 +1,125 @@ +package streams + +import ( + "context" + "fmt" + "sync" + + "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/services/job" + "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" +) + +type Runner interface { + ExecuteRun(ctx context.Context, spec pipeline.Spec, vars pipeline.Vars) (run *pipeline.Run, trrs pipeline.TaskRunResults, err error) + InitializePipeline(spec pipeline.Spec) (*pipeline.Pipeline, error) +} + +type RunResultSaver interface { + Save(run *pipeline.Run) +} + +// TODO: Rename this one to MultiStreamPipeline? +type Pipeline interface { + Run(ctx context.Context) (*pipeline.Run, pipeline.TaskRunResults, error) + StreamIDs() []StreamID +} + +type multiStreamPipeline struct { + sync.RWMutex + lggr logger.Logger + spec pipeline.Spec + runner Runner + rrs RunResultSaver + streamIDs []StreamID +} + +func NewMultiStreamPipeline(lggr logger.Logger, jb job.Job, runner Runner, rrs RunResultSaver) (Pipeline, error) { + return newMultiStreamPipeline(lggr, jb, runner, rrs) +} + +func newMultiStreamPipeline(lggr logger.Logger, jb job.Job, runner Runner, rrs RunResultSaver) (*multiStreamPipeline, error) { + if jb.PipelineSpec == nil { + // should never happen + return nil, fmt.Errorf("job has no pipeline spec") + } + spec := *jb.PipelineSpec + if spec.Pipeline == nil { + pipeline, err := spec.ParsePipeline() + if err != nil { + return nil, fmt.Errorf("unparseable pipeline: %w", err) + } + + spec.Pipeline = pipeline + // initialize it for the given runner + if _, err := runner.InitializePipeline(spec); err != nil { + return nil, fmt.Errorf("error while initializing pipeline: %w", err) + } + } + var streamIDs []StreamID + for _, t := range spec.Pipeline.Tasks { + if t.TaskStreamID() != nil { + streamIDs = append(streamIDs, *t.TaskStreamID()) + } + } + if jb.StreamID != nil { + streamIDs = append(streamIDs, *jb.StreamID) + } + return &multiStreamPipeline{sync.RWMutex{}, lggr.Named("MultiStreamPipeline").With("spec.ID", spec.ID, "jobID", spec.JobID, "jobName", spec.JobName, "jobType", spec.JobType), spec, runner, rrs, streamIDs}, nil +} + +func (s *multiStreamPipeline) Run(ctx context.Context) (run *pipeline.Run, trrs pipeline.TaskRunResults, err error) { + run, trrs, err = s.executeRun(ctx) + + if err != nil { + return nil, nil, fmt.Errorf("Run failed: %w", err) + } + if s.rrs != nil { + s.rrs.Save(run) + } + + return +} + +func (s *multiStreamPipeline) StreamIDs() []StreamID { + return s.streamIDs +} + +// The context passed in here has a timeout of (ObservationTimeout + ObservationGracePeriod). +// Upon context cancellation, its expected that we return any usable values within ObservationGracePeriod. +func (s *multiStreamPipeline) executeRun(ctx context.Context) (*pipeline.Run, pipeline.TaskRunResults, error) { + // the hot path here is to avoid parsing and use the pre-parsed, cached, pipeline + s.RLock() + // TODO: move this up to new + initialize := s.spec.Pipeline == nil + s.RUnlock() + if initialize { + pipeline, err := s.spec.ParsePipeline() + if err != nil { + return nil, nil, fmt.Errorf("Run failed due to unparseable pipeline: %w", err) + } + + s.Lock() + if s.spec.Pipeline == nil { + s.spec.Pipeline = pipeline + // initialize it for the given runner + if _, err := s.runner.InitializePipeline(s.spec); err != nil { + return nil, nil, fmt.Errorf("Run failed due to error while initializing pipeline: %w", err) + } + } + s.Unlock() + } + + vars := pipeline.NewVarsFrom(map[string]interface{}{ + "pipelineSpec": map[string]interface{}{ + "id": s.spec.ID, + }, + }) + + run, trrs, err := s.runner.ExecuteRun(ctx, s.spec, vars) + if err != nil { + return nil, nil, fmt.Errorf("error executing run for spec ID %v: %w", s.spec.ID, err) + } + + return run, trrs, err +} diff --git a/core/services/streams/stream.go b/core/services/streams/stream.go deleted file mode 100644 index b65c6dc12f6..00000000000 --- a/core/services/streams/stream.go +++ /dev/null @@ -1,94 +0,0 @@ -package streams - -import ( - "context" - "fmt" - "sync" - - "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" -) - -type Runner interface { - ExecuteRun(ctx context.Context, spec pipeline.Spec, vars pipeline.Vars) (run *pipeline.Run, trrs pipeline.TaskRunResults, err error) - InitializePipeline(spec pipeline.Spec) (*pipeline.Pipeline, error) -} - -type RunResultSaver interface { - Save(run *pipeline.Run) -} - -type Stream interface { - Run(ctx context.Context) (*pipeline.Run, pipeline.TaskRunResults, error) -} - -type stream struct { - sync.RWMutex - id StreamID - lggr logger.Logger - spec *pipeline.Spec - runner Runner - rrs RunResultSaver -} - -func NewStream(lggr logger.Logger, id StreamID, spec pipeline.Spec, runner Runner, rrs RunResultSaver) Stream { - return newStream(lggr, id, spec, runner, rrs) -} - -func newStream(lggr logger.Logger, id StreamID, spec pipeline.Spec, runner Runner, rrs RunResultSaver) *stream { - return &stream{sync.RWMutex{}, id, lggr.Named("Stream").With("streamID", id), &spec, runner, rrs} -} - -func (s *stream) Run(ctx context.Context) (run *pipeline.Run, trrs pipeline.TaskRunResults, err error) { - run, trrs, err = s.executeRun(ctx) - - if err != nil { - return nil, nil, fmt.Errorf("Run failed: %w", err) - } - if s.rrs != nil { - s.rrs.Save(run) - } - - return -} - -// The context passed in here has a timeout of (ObservationTimeout + ObservationGracePeriod). -// Upon context cancellation, its expected that we return any usable values within ObservationGracePeriod. -func (s *stream) executeRun(ctx context.Context) (*pipeline.Run, pipeline.TaskRunResults, error) { - // the hot path here is to avoid parsing and use the pre-parsed, cached, pipeline - s.RLock() - initialize := s.spec.Pipeline == nil - s.RUnlock() - if initialize { - pipeline, err := s.spec.ParsePipeline() - if err != nil { - return nil, nil, fmt.Errorf("Run failed due to unparseable pipeline: %w", err) - } - - s.Lock() - if s.spec.Pipeline == nil { - s.spec.Pipeline = pipeline - // initialize it for the given runner - if _, err := s.runner.InitializePipeline(*s.spec); err != nil { - return nil, nil, fmt.Errorf("Run failed due to error while initializing pipeline: %w", err) - } - } - s.Unlock() - } - - vars := pipeline.NewVarsFrom(map[string]interface{}{ - "pipelineSpec": map[string]interface{}{ - "id": s.spec.ID, - }, - "stream": map[string]interface{}{ - "id": s.id, - }, - }) - - run, trrs, err := s.runner.ExecuteRun(ctx, *s.spec, vars) - if err != nil { - return nil, nil, fmt.Errorf("error executing run for spec ID %v: %w", s.spec.ID, err) - } - - return run, trrs, err -} diff --git a/core/services/streams/stream_registry.go b/core/services/streams/stream_registry.go index 9ab2df11d33..3fadd0bac12 100644 --- a/core/services/streams/stream_registry.go +++ b/core/services/streams/stream_registry.go @@ -7,27 +7,32 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/types/llo" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" + "github.com/smartcontractkit/chainlink/v2/core/services/job" ) +// TODO: Rename, this is actually a PipelineRegistry (? is it ?) + // alias for easier refactoring type StreamID = llo.StreamID type Registry interface { Getter - Register(streamID StreamID, spec pipeline.Spec, rrs ResultRunSaver) error - Unregister(streamID StreamID) + Register(jb job.Job, rrs ResultRunSaver) error + Unregister(jobID int32) } type Getter interface { - Get(streamID StreamID) (strm Stream, exists bool) + Get(streamID StreamID) (p Pipeline, exists bool) } type streamRegistry struct { sync.RWMutex - lggr logger.Logger - runner Runner - streams map[StreamID]Stream + lggr logger.Logger + runner Runner + // keyed by stream ID + pipelines map[StreamID]Pipeline + // keyed by job ID + pipelinesByJobID map[int32]Pipeline } func NewRegistry(lggr logger.Logger, runner Runner) Registry { @@ -39,29 +44,52 @@ func newRegistry(lggr logger.Logger, runner Runner) *streamRegistry { sync.RWMutex{}, lggr.Named("Registry"), runner, - make(map[StreamID]Stream), + make(map[StreamID]Pipeline), + make(map[int32]Pipeline), } } -func (s *streamRegistry) Get(streamID StreamID) (strm Stream, exists bool) { +func (s *streamRegistry) Get(streamID StreamID) (p Pipeline, exists bool) { s.RLock() defer s.RUnlock() - strm, exists = s.streams[streamID] + p, exists = s.pipelines[streamID] return } -func (s *streamRegistry) Register(streamID StreamID, spec pipeline.Spec, rrs ResultRunSaver) error { +func (s *streamRegistry) Register(jb job.Job, rrs ResultRunSaver) error { + if jb.Type != job.Stream { + return fmt.Errorf("cannot register job type %s; only Stream jobs are supported", jb.Type) + } s.Lock() defer s.Unlock() - if _, exists := s.streams[streamID]; exists { - return fmt.Errorf("stream already registered for id: %d", streamID) + if _, exists := s.pipelinesByJobID[jb.ID]; exists { + return fmt.Errorf("cannot register job with ID: %d; it is already registered", jb.ID) + } + p, err := NewMultiStreamPipeline(s.lggr, jb, s.runner, rrs) + if err != nil { + return fmt.Errorf("cannot register job with ID: %d; %w", jb.ID, err) + } + s.pipelinesByJobID[jb.ID] = p + // FIXME: Naming is so awkward, call it a Multistream or something instead? Or combistream? + streamIDs := p.StreamIDs() + for _, strmID := range streamIDs { + if _, exists := s.pipelines[strmID]; exists { + return fmt.Errorf("cannot register job with ID: %d; stream id %d is already registered", jb.ID, strmID) + } + s.pipelines[strmID] = p } - s.streams[streamID] = NewStream(s.lggr, streamID, spec, s.runner, rrs) return nil } -func (s *streamRegistry) Unregister(streamID StreamID) { +func (s *streamRegistry) Unregister(jobID int32) { s.Lock() defer s.Unlock() - delete(s.streams, streamID) + p, exists := s.pipelinesByJobID[jobID] + if !exists { + return + } + streamIDs := p.StreamIDs() + for _, id := range streamIDs { + delete(s.pipelines, id) + } }