diff --git a/cmd/livepeer/starter/starter.go b/cmd/livepeer/starter/starter.go
index 00c23ec3db..6f3802b1ad 100755
--- a/cmd/livepeer/starter/starter.go
+++ b/cmd/livepeer/starter/starter.go
@@ -1328,6 +1328,21 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
 			if *cfg.Network != "offchain" {
 				n.SetBasePriceForCap("default", core.Capability_SegmentAnything2, config.ModelID, autoPrice)
 			}
+
+		case "llm-generate":
+			_, ok := capabilityConstraints[core.Capability_LlmGenerate]
+			if !ok {
+				aiCaps = append(aiCaps, core.Capability_LlmGenerate)
+				capabilityConstraints[core.Capability_LlmGenerate] = &core.CapabilityConstraints{
+					Models: make(map[string]*core.ModelConstraint),
+				}
+			}
+
+			capabilityConstraints[core.Capability_LlmGenerate].Models[config.ModelID] = modelConstraint
+
+			if *cfg.Network != "offchain" {
+				n.SetBasePriceForCap("default", core.Capability_LlmGenerate, config.ModelID, autoPrice)
+			}
 		}
 
 		if len(aiCaps) > 0 {
diff --git a/core/ai.go b/core/ai.go
index 31f331e49e..0b7d223419 100644
--- a/core/ai.go
+++ b/core/ai.go
@@ -23,6 +23,7 @@ type AI interface {
 	Upscale(context.Context, worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
 	AudioToText(context.Context, worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
 	SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
+	LlmGenerate(context.Context, worker.GenLlmFormdataRequestBody) (interface{}, error)
 	Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
 	Stop(context.Context) error
 	HasCapacity(pipeline, modelID string) bool
diff --git a/core/capabilities.go b/core/capabilities.go
index fc9e5217ba..2956e1f08d 100644
--- a/core/capabilities.go
+++ b/core/capabilities.go
@@ -79,6 +79,7 @@ const (
 	Capability_Upscale
 	Capability_AudioToText
 	Capability_SegmentAnything2
+	Capability_LlmGenerate
 )
 
 var CapabilityNameLookup = map[Capability]string{
@@ -116,6 +117,7 @@ var CapabilityNameLookup = map[Capability]string{
 	Capability_Upscale:          "Upscale",
 	Capability_AudioToText:      "Audio to text",
 	Capability_SegmentAnything2: "Segment anything 2",
+	Capability_LlmGenerate:      "LLM Generate",
 }
 
 var CapabilityTestLookup = map[Capability]CapabilityTest{
diff --git a/core/orchestrator.go b/core/orchestrator.go
index f8e343ae32..aff5303968 100644
--- a/core/orchestrator.go
+++ b/core/orchestrator.go
@@ -134,6 +134,11 @@ func (orch *orchestrator) SegmentAnything2(ctx context.Context, req worker.GenSe
 	return orch.node.SegmentAnything2(ctx, req)
 }
 
+// LlmGenerate returns a *worker.LlmResponse for non-streaming requests, or a channel of stream chunks when streaming is requested.
+func (orch *orchestrator) LlmGenerate(ctx context.Context, req worker.GenLlmFormdataRequestBody) (interface{}, error) {
+	return orch.node.llmGenerate(ctx, req)
+}
+
 func (orch *orchestrator) ProcessPayment(ctx context.Context, payment net.Payment, manifestID ManifestID) error {
 	if orch.node == nil || orch.node.Recipient == nil {
 		return nil
@@ -1051,6 +1056,10 @@ func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.GenImageToVi
 	return &worker.ImageResponse{Images: videos}, nil
 }
 
+func (n *LivepeerNode) llmGenerate(ctx context.Context, req worker.GenLlmFormdataRequestBody) (interface{}, error) {
+	return n.AIWorker.LlmGenerate(ctx, req)
+}
+
 func (rtm *RemoteTranscoderManager) transcoderResults(tcID int64, res *RemoteTranscoderResult) {
 	remoteChan, err := rtm.getTaskChan(tcID)
 	if err != nil {
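A note on the `interface{}` return used throughout these core changes: depending on whether streaming was requested, the value is either a complete response or a channel of chunks. The sketch below (not part of the diff) shows how a caller might fold either shape into a single string; `Done` and `TokensUsed` appear in this PR, while the text-bearing fields (`Response`, `Chunk`) are assumptions about the generated worker types.

```go
package server

import (
	"fmt"
	"strings"

	"github.com/livepeer/ai-worker/worker"
)

// collectLlmResult is a sketch of how a caller might consume the interface{}
// value returned by the new LlmGenerate methods. The gateway path returns a
// bidirectional chan worker.LlmStreamChunk, which would need its own case.
func collectLlmResult(resp interface{}) (string, error) {
	switch v := resp.(type) {
	case *worker.LlmResponse:
		// Non-streaming: the full text is already assembled.
		return v.Response, nil // field name assumed
	case <-chan worker.LlmStreamChunk:
		// Streaming: concatenate chunks until the final one is marked Done.
		var sb strings.Builder
		for chunk := range v {
			sb.WriteString(chunk.Chunk) // field name assumed
			if chunk.Done {
				break
			}
		}
		return sb.String(), nil
	default:
		return "", fmt.Errorf("unexpected LlmGenerate result type %T", resp)
	}
}
```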
diff --git a/go.mod b/go.mod
index f84f4a6b37..d11b85fab3 100644
--- a/go.mod
+++ b/go.mod
@@ -238,3 +238,5 @@ require (
 	lukechampine.com/blake3 v1.2.1 // indirect
 	rsc.io/tmplfunc v0.0.3 // indirect
 )
+
+replace github.com/livepeer/ai-worker => github.com/ad-astra-video/ai-worker v0.0.0-20240921034803-5d83b83b7a1c
diff --git a/go.sum b/go.sum
index 8c7b8cd8cb..3173327fae 100644
--- a/go.sum
+++ b/go.sum
@@ -72,6 +72,8 @@ github.com/StackExchange/wmi v1.2.1 h1:VIkavFPXSjcnS+O8yTq7NI32k0R5Aj+v39y29VYDO
 github.com/StackExchange/wmi v1.2.1/go.mod h1:rcmrprowKIVzvc+NUiLncP2uuArMWLCbu9SBzvHz7e8=
 github.com/VictoriaMetrics/fastcache v1.12.1 h1:i0mICQuojGDL3KblA7wUNlY5lOK6a4bwt3uRKnkZU40=
 github.com/VictoriaMetrics/fastcache v1.12.1/go.mod h1:tX04vaqcNoQeGLD+ra5pU5sWkuxnzWhEzLwhP9w653o=
+github.com/ad-astra-video/ai-worker v0.0.0-20240921034803-5d83b83b7a1c h1:P1cDtj2uFXuYa1A68NXcocGxvcLt7J/XbjYKVH4LUJ4=
+github.com/ad-astra-video/ai-worker v0.0.0-20240921034803-5d83b83b7a1c/go.mod h1:91lMzkzVuwR9kZ0EzXwf+7yVhLaNVmYAfmBtn7t3cQA=
 github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
 github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
 github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
@@ -623,8 +625,6 @@ github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4n
 github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI=
 github.com/libp2p/go-openssl v0.1.0 h1:LBkKEcUv6vtZIQLVTegAil8jbNpJErQ9AnT+bWV+Ooo=
 github.com/libp2p/go-openssl v0.1.0/go.mod h1:OiOxwPpL3n4xlenjx2h7AwSGaFSC/KZvf6gNdOBQMtc=
-github.com/livepeer/ai-worker v0.5.0 h1:dgO6j9QVFPOq9omIcgB1YmgVSlhV94BMb6QO4WUocX8=
-github.com/livepeer/ai-worker v0.5.0/go.mod h1:91lMzkzVuwR9kZ0EzXwf+7yVhLaNVmYAfmBtn7t3cQA=
 github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b h1:VQcnrqtCA2UROp7q8ljkh2XA/u0KRgVv0S1xoUvOweE=
 github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b/go.mod h1:hwJ5DKhl+pTanFWl+EUpw1H7ukPO/H+MFpgA7jjshzw=
 github.com/livepeer/joy4 v0.1.2-0.20191121080656-b2fea45cbded h1:ZQlvR5RB4nfT+cOQee+WqmaDOgGtP2oDMhcVvR4L0yA=
diff --git a/server/ai_http.go b/server/ai_http.go
index 3f0bb97d9e..e814b9d92e 100644
--- a/server/ai_http.go
+++ b/server/ai_http.go
@@ -45,6 +45,7 @@ func startAIServer(lp lphttp) error {
 	lp.transRPC.Handle("/upscale", oapiReqValidator(lp.Upscale()))
 	lp.transRPC.Handle("/audio-to-text", oapiReqValidator(lp.AudioToText()))
 	lp.transRPC.Handle("/segment-anything-2", oapiReqValidator(lp.SegmentAnything2()))
+	lp.transRPC.Handle("/llm-generate", oapiReqValidator(lp.LlmGenerate()))
 
 	return nil
 }
@@ -181,6 +182,29 @@ func (h *lphttp) SegmentAnything2() http.Handler {
 	})
 }
+func (h *lphttp) LlmGenerate() http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		orch := h.orchestrator
+
+		remoteAddr := getRemoteAddr(r)
+		ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)
+
+		multiRdr, err := r.MultipartReader()
+		if err != nil {
+			respondWithError(w, err.Error(), http.StatusBadRequest)
+			return
+		}
+
+		var req worker.GenLlmFormdataRequestBody
+		if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
+			respondWithError(w, err.Error(), http.StatusInternalServerError)
+			return
+		}
+
+		handleAIRequest(ctx, w, r, orch, req)
+	})
+}
+
 func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, orch Orchestrator, req interface{}) {
 	payment, err := getPayment(r.Header.Get(paymentHeader))
 	if err != nil {
 		respondWithError(w, err.Error(), http.StatusBadRequest)
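For context, a minimal client for the new orchestrator route might look like the following sketch. The multipart field names (`prompt`, `model_id`, `max_tokens`, `stream`) are assumptions inferred from `GenLlmFormdataRequestBody` and are not taken from this diff, the orchestrator address is hypothetical, and the payment/capability headers a real gateway sends are omitted.

```go
package main

import (
	"bufio"
	"bytes"
	"fmt"
	"mime/multipart"
	"net/http"
	"strings"
)

func main() {
	var buf bytes.Buffer
	mw := multipart.NewWriter(&buf)
	// Field names are assumptions based on the request body type.
	_ = mw.WriteField("prompt", "Say hello in one sentence.")
	_ = mw.WriteField("model_id", "meta-llama/llama-3.1-8B-Instruct")
	_ = mw.WriteField("max_tokens", "64")
	_ = mw.WriteField("stream", "true")
	mw.Close()

	// Hypothetical orchestrator address; payment headers omitted.
	resp, err := http.Post("https://orchestrator:8935/llm-generate", mw.FormDataContentType(), &buf)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Streaming responses arrive as server-sent events ("data: {...}" frames).
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := scanner.Text()
		if strings.HasPrefix(line, "data: ") {
			fmt.Println(strings.TrimPrefix(line, "data: "))
		}
	}
}
```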
@@ -324,6 +348,21 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
 			return
 		}
 		outPixels = int64(config.Height) * int64(config.Width)
+	case worker.GenLlmFormdataRequestBody:
+		pipeline = "llm-generate"
+		cap = core.Capability_LlmGenerate
+		modelID = *v.ModelId
+		submitFn = func(ctx context.Context) (interface{}, error) {
+			return orch.LlmGenerate(ctx, v)
+		}
+
+		if v.MaxTokens == nil {
+			respondWithError(w, "MaxTokens not specified", http.StatusBadRequest)
+			return
+		}
+
+		// TODO: Improve pricing
+		outPixels = int64(*v.MaxTokens)
 	default:
 		respondWithError(w, "Unknown request type", http.StatusBadRequest)
 		return
@@ -407,7 +446,37 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
 		monitor.AIJobProcessed(ctx, pipeline, modelID, monitor.AIJobInfo{LatencyScore: latencyScore, PricePerUnit: pricePerAIUnit})
 	}
 
-	w.Header().Set("Content-Type", "application/json")
-	w.WriteHeader(http.StatusOK)
-	_ = json.NewEncoder(w).Encode(resp)
+	// Check if the response is a streaming response
+	if streamChan, ok := resp.(<-chan worker.LlmStreamChunk); ok {
+		// Set headers for SSE
+		w.Header().Set("Content-Type", "text/event-stream")
+		w.Header().Set("Cache-Control", "no-cache")
+		w.Header().Set("Connection", "keep-alive")
+
+		flusher, ok := w.(http.Flusher)
+		if !ok {
+			http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
+			return
+		}
+
+		for chunk := range streamChan {
+			data, err := json.Marshal(chunk)
+			if err != nil {
+				clog.Errorf(ctx, "Error marshaling stream chunk: %v", err)
+				continue
+			}
+
+			fmt.Fprintf(w, "data: %s\n\n", data)
+			flusher.Flush()
+
+			if chunk.Done {
+				break
+			}
+		}
+	} else {
+		// Non-streaming response
+		w.Header().Set("Content-Type", "application/json")
+		w.WriteHeader(http.StatusOK)
+		_ = json.NewEncoder(w).Encode(resp)
+	}
 }
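The new `case` above prices LLM work by token count: `outPixels` is set to `MaxTokens`, so each token is billed like one output pixel. An illustrative estimate of the resulting fee, using the same `PricePerUnit`/`PixelsPerUnit` ratio the monitoring code in this PR reports, is sketched below (function name and values are hypothetical).

```go
// Illustrative only: estimate the fee for an LLM job priced per token, using
// the PricePerUnit / PixelsPerUnit ratio reported by the monitoring code.
func estimateLlmFeeWei(maxTokens, pricePerUnit, pixelsPerUnit int64) int64 {
	if pixelsPerUnit == 0 {
		return 0
	}
	// One token is treated as one "pixel" unit (outPixels = MaxTokens above).
	return maxTokens * pricePerUnit / pixelsPerUnit
}

// Example: 256 tokens at 1000 wei per unit with pixelsPerUnit = 1 -> 256000 wei.
```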
diff --git a/server/ai_mediaserver.go b/server/ai_mediaserver.go
index 078fa05ee9..e4c7d6176b 100644
--- a/server/ai_mediaserver.go
+++ b/server/ai_mediaserver.go
@@ -5,6 +5,7 @@ import (
 	"context"
 	"encoding/json"
 	"errors"
+	"fmt"
 	"net/http"
 	"time"
 
@@ -70,7 +71,7 @@ func startAIMediaServer(ls *LivepeerServer) error {
 	ls.HTTPMux.Handle("/image-to-video/result", ls.ImageToVideoResult())
 	ls.HTTPMux.Handle("/audio-to-text", oapiReqValidator(ls.AudioToText()))
 	ls.HTTPMux.Handle("/segment-anything-2", oapiReqValidator(ls.SegmentAnything2()))
-
+	ls.HTTPMux.Handle("/llm-generate", oapiReqValidator(ls.LlmGenerate()))
 	return nil
 }
 
@@ -428,6 +429,78 @@ func (ls *LivepeerServer) SegmentAnything2() http.Handler {
 	})
 }
 
+func (ls *LivepeerServer) LlmGenerate() http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		remoteAddr := getRemoteAddr(r)
+		ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)
+		requestID := string(core.RandomManifestID())
+		ctx = clog.AddVal(ctx, "request_id", requestID)
+
+		var req worker.GenLlmFormdataRequestBody
+
+		multiRdr, err := r.MultipartReader()
+		if err != nil {
+			respondJsonError(ctx, w, err, http.StatusBadRequest)
+			return
+		}
+
+		if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
+			respondJsonError(ctx, w, err, http.StatusBadRequest)
+			return
+		}
+
+		streamResponse := false
+		if req.Stream != nil {
+			streamResponse = *req.Stream
+		}
+
+		clog.V(common.VERBOSE).Infof(ctx, "Received LlmGenerate request prompt=%v model_id=%v stream=%v", req.Prompt, *req.ModelId, streamResponse)
+
+		params := aiRequestParams{
+			node:        ls.LivepeerNode,
+			os:          drivers.NodeStorage.NewSession(requestID),
+			sessManager: ls.AISessionManager,
+		}
+
+		start := time.Now()
+		resp, err := processLlmGenerate(ctx, params, req)
+		if err != nil {
+			var e *ServiceUnavailableError
+			if errors.As(err, &e) {
+				respondJsonError(ctx, w, err, http.StatusServiceUnavailable)
+				return
+			}
+			respondJsonError(ctx, w, err, http.StatusInternalServerError)
+			return
+		}
+
+		took := time.Since(start)
+		clog.V(common.VERBOSE).Infof(ctx, "Processed LlmGenerate request prompt=%v model_id=%v took=%v", req.Prompt, *req.ModelId, took)
+
+		if streamChan, ok := resp.(chan worker.LlmStreamChunk); ok {
+			// Handle streaming response (SSE)
+			w.Header().Set("Content-Type", "text/event-stream")
+			w.Header().Set("Cache-Control", "no-cache")
+			w.Header().Set("Connection", "keep-alive")
+
+			for chunk := range streamChan {
+				data, _ := json.Marshal(chunk)
+				fmt.Fprintf(w, "data: %s\n\n", data)
+				w.(http.Flusher).Flush()
+				if chunk.Done {
+					break
+				}
+			}
+		} else if llmResp, ok := resp.(*worker.LlmResponse); ok {
+			// Handle non-streaming response
+			w.Header().Set("Content-Type", "application/json")
+			json.NewEncoder(w).Encode(llmResp)
+		} else {
+			http.Error(w, "Unexpected response type", http.StatusInternalServerError)
+		}
+	})
+}
+
 func (ls *LivepeerServer) ImageToVideoResult() http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		remoteAddr := getRemoteAddr(r)
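Unlike the orchestrator handler earlier in this diff, the gateway handler above asserts `w.(http.Flusher)` without checking the assertion. A defensive variant of the relay loop, mirroring the checked pattern in `server/ai_http.go`, might look like the sketch below (the function is hypothetical; imports and the chunk type come from the surrounding package).

```go
// relayLlmStream is a sketch (not in this PR) of the gateway's SSE relay loop
// with a checked http.Flusher assertion, as done in server/ai_http.go above.
func relayLlmStream(w http.ResponseWriter, streamChan <-chan worker.LlmStreamChunk) {
	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "streaming unsupported", http.StatusInternalServerError)
		return
	}
	for chunk := range streamChan {
		data, err := json.Marshal(chunk)
		if err != nil {
			continue
		}
		fmt.Fprintf(w, "data: %s\n\n", data)
		flusher.Flush()
		if chunk.Done {
			break
		}
	}
}
```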
diff --git a/server/ai_process.go b/server/ai_process.go
index 249cc506b6..af4704069a 100644
--- a/server/ai_process.go
+++ b/server/ai_process.go
@@ -32,6 +32,7 @@ const defaultImageToVideoModelID = "stabilityai/stable-video-diffusion-img2vid-x
 const defaultUpscaleModelID = "stabilityai/stable-diffusion-x4-upscaler"
 const defaultAudioToTextModelID = "openai/whisper-large-v3"
 const defaultSegmentAnything2ModelID = "facebook/sam2-hiera-large"
+const defaultLlmGenerateModelID = "meta-llama/llama-3.1-8B-Instruct"
 
 type ServiceUnavailableError struct {
 	err error
@@ -792,6 +793,164 @@ func submitAudioToText(ctx context.Context, params aiRequestParams, sess *AISess
 	return &res, nil
 }
 
+func CalculateLlmGenerateLatencyScore(took time.Duration, tokensUsed int) float64 {
+	if tokensUsed <= 0 {
+		return 0
+	}
+
+	return took.Seconds() / float64(tokensUsed)
+}
+
+func processLlmGenerate(ctx context.Context, params aiRequestParams, req worker.GenLlmFormdataRequestBody) (interface{}, error) {
+	resp, err := processAIRequest(ctx, params, req)
+	if err != nil {
+		return nil, err
+	}
+
+	if req.Stream != nil && *req.Stream {
+		streamChan, ok := resp.(chan worker.LlmStreamChunk)
+		if !ok {
+			return nil, errors.New("unexpected response type for streaming request")
+		}
+		return streamChan, nil
+	}
+
+	llmResp, ok := resp.(*worker.LlmResponse)
+	if !ok {
+		return nil, errors.New("unexpected response type")
+	}
+
+	return llmResp, nil
+}
+
+func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenLlmFormdataRequestBody) (interface{}, error) {
+	var buf bytes.Buffer
+	mw, err := worker.NewLlmGenerateMultipartWriter(&buf, req)
+	if err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, nil)
+		}
+		return nil, err
+	}
+
+	client, err := worker.NewClientWithResponses(sess.Transcoder(), worker.WithHTTPClient(httpClient))
+	if err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+
+	// TODO: Improve pricing
+	if req.MaxTokens == nil {
+		req.MaxTokens = new(int)
+		*req.MaxTokens = 256
+	}
+
+	setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, int64(*req.MaxTokens))
+	if err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+	defer completeBalanceUpdate(sess.BroadcastSession, balUpdate)
+
+	start := time.Now()
+	resp, err := client.GenLlmWithBody(ctx, mw.FormDataContentType(), &buf, setHeaders)
+	if err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+
+	if resp.StatusCode != http.StatusOK {
+		body, _ := io.ReadAll(resp.Body)
+		return nil, fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body))
+	}
+
+	if req.Stream != nil && *req.Stream {
+		return handleSSEStream(ctx, resp.Body, sess, req, start)
+	}
+
+	return handleNonStreamingResponse(ctx, resp.Body, sess, req, start)
+}
+
+func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLlmFormdataRequestBody, start time.Time) (chan worker.LlmStreamChunk, error) {
+	streamChan := make(chan worker.LlmStreamChunk, 100)
+	go func() {
+		defer close(streamChan)
+		defer body.Close()
+		scanner := bufio.NewScanner(body)
+		var totalTokens int
+		for scanner.Scan() {
+			line := scanner.Text()
+			if strings.HasPrefix(line, "data: ") {
+				data := strings.TrimPrefix(line, "data: ")
+				if data == "[DONE]" {
+					streamChan <- worker.LlmStreamChunk{Done: true, TokensUsed: totalTokens}
+					break
+				}
+				var chunk worker.LlmStreamChunk
+				if err := json.Unmarshal([]byte(data), &chunk); err != nil {
+					clog.Errorf(ctx, "Error unmarshaling SSE data: %v", err)
+					continue
+				}
+				totalTokens += chunk.TokensUsed
+				streamChan <- chunk
+			}
+		}
+		if err := scanner.Err(); err != nil {
+			clog.Errorf(ctx, "Error reading SSE stream: %v", err)
+		}
+
+		took := time.Since(start)
+		sess.LatencyScore = CalculateLlmGenerateLatencyScore(took, totalTokens)
+
+		if monitor.Enabled {
+			var pricePerAIUnit float64
+			if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 {
+				pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit)
+			}
+			monitor.AIRequestFinished(ctx, "llm-generate", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
+		}
+	}()
+
+	return streamChan, nil
+}
+
+func handleNonStreamingResponse(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLlmFormdataRequestBody, start time.Time) (*worker.LlmResponse, error) {
+	data, err := io.ReadAll(body)
+	defer body.Close()
+	if err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+
+	var res worker.LlmResponse
+	if err := json.Unmarshal(data, &res); err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+
+	took := time.Since(start)
+	sess.LatencyScore = CalculateLlmGenerateLatencyScore(took, res.TokensUsed)
+
+	if monitor.Enabled {
+		var pricePerAIUnit float64
+		if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 {
+			pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit)
+		}
+		monitor.AIRequestFinished(ctx, "llm-generate", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
+	}
+
+	return &res, nil
+}
+
 func processAIRequest(ctx context.Context, params aiRequestParams, req interface{}) (interface{}, error) {
 	var cap core.Capability
 	var modelID string
@@ -852,6 +1011,15 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
 		submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
 			return submitSegmentAnything2(ctx, params, sess, v)
 		}
+	case worker.GenLlmFormdataRequestBody:
+		cap = core.Capability_LlmGenerate
+		modelID = defaultLlmGenerateModelID
+		if v.ModelId != nil {
+			modelID = *v.ModelId
+		}
+		submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
+			return submitLlmGenerate(ctx, params, sess, v)
+		}
 	default:
 		return nil, fmt.Errorf("unsupported request type %T", req)
 	}
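The test file added below is scaffolded with an empty case table. A concrete check that is easy to make self-contained (hypothetical, not part of this diff) is a unit test for the pure latency helper introduced above, which reports seconds per generated token:

```go
package server

import (
	"testing"
	"time"
)

// Hypothetical test, not part of this PR: CalculateLlmGenerateLatencyScore
// returns seconds per token, and 0 when no tokens were produced.
func TestCalculateLlmGenerateLatencyScore(t *testing.T) {
	if got := CalculateLlmGenerateLatencyScore(2*time.Second, 100); got != 0.02 {
		t.Errorf("expected 0.02 s/token, got %v", got)
	}
	if got := CalculateLlmGenerateLatencyScore(time.Second, 0); got != 0 {
		t.Errorf("expected 0 for zero tokens, got %v", got)
	}
}
```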
diff --git a/server/ai_process_test.go b/server/ai_process_test.go
new file mode 100644
index 0000000000..e64382fb38
--- /dev/null
+++ b/server/ai_process_test.go
@@ -0,0 +1,38 @@
+package server
+
+import (
+	"context"
+	"reflect"
+	"testing"
+
+	"github.com/livepeer/ai-worker/worker"
+)
+
+func Test_submitLlmGenerate(t *testing.T) {
+	type args struct {
+		ctx    context.Context
+		params aiRequestParams
+		sess   *AISession
+		req    worker.GenLlmFormdataRequestBody
+	}
+	tests := []struct {
+		name    string
+		args    args
+		want    interface{}
+		wantErr bool
+	}{
+		// TODO: Add test cases.
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := submitLlmGenerate(tt.args.ctx, tt.args.params, tt.args.sess, tt.args.req)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("submitLlmGenerate() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			if !reflect.DeepEqual(got, tt.want) {
+				t.Errorf("submitLlmGenerate() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
diff --git a/server/rpc.go b/server/rpc.go
index 6c1365ccd6..4732567701 100644
--- a/server/rpc.go
+++ b/server/rpc.go
@@ -69,6 +69,7 @@ type Orchestrator interface {
 	Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
 	AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
 	SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
+	LlmGenerate(ctx context.Context, req worker.GenLlmFormdataRequestBody) (interface{}, error)
 }
 
 // Balance describes methods for a session's balance maintenance
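Because `rpc.go` widens the `Orchestrator` interface, any stub orchestrator used by the server tests also needs the new method. A minimal sketch is shown below; the `stubOrchestrator` receiver is assumed to be the existing test stub and is hypothetical here, as is the canned response.

```go
// Hypothetical test helper, not part of this PR: a canned LlmGenerate
// implementation for whatever stub type already satisfies the rest of the
// Orchestrator interface in the server tests.
func (o *stubOrchestrator) LlmGenerate(ctx context.Context, req worker.GenLlmFormdataRequestBody) (interface{}, error) {
	// Return a fixed, non-streaming response; tests that need streaming can
	// return a pre-filled chan worker.LlmStreamChunk instead.
	return &worker.LlmResponse{TokensUsed: 0}, nil
}
```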