From 908270758c92712a6e1e344414e023b642e83990 Mon Sep 17 00:00:00 2001 From: kyriediculous Date: Wed, 31 Jul 2024 22:23:48 +0200 Subject: [PATCH 1/4] cmd,core,server: llm pipeline with stream support --- cmd/livepeer/starter/starter.go | 16 ++++ core/ai.go | 1 + core/capabilities.go | 2 + core/orchestrator.go | 9 ++ go.mod | 2 + server/ai_http.go | 69 +++++++++++++- server/ai_mediaserver.go | 69 ++++++++++++++ server/ai_process.go | 163 ++++++++++++++++++++++++++++++++ server/ai_process_test.go | 38 ++++++++ server/rpc.go | 1 + 10 files changed, 367 insertions(+), 3 deletions(-) create mode 100644 server/ai_process_test.go diff --git a/cmd/livepeer/starter/starter.go b/cmd/livepeer/starter/starter.go index 99eae3558a..0ad4a614b1 100755 --- a/cmd/livepeer/starter/starter.go +++ b/cmd/livepeer/starter/starter.go @@ -1318,6 +1318,22 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) { if *cfg.Network != "offchain" { n.SetBasePriceForCap("default", core.Capability_AudioToText, config.ModelID, autoPrice) } + n.SetBasePriceForCap("default", core.Capability_AudioToText, config.ModelID, autoPrice) + + case "llm-generate": + _, ok := capabilityConstraints[core.Capability_LlmGenerate] + if !ok { + aiCaps = append(aiCaps, core.Capability_LlmGenerate) + capabilityConstraints[core.Capability_LlmGenerate] = &core.PerCapabilityConstraints{ + Models: make(map[string]*core.ModelConstraint), + } + } + + capabilityConstraints[core.Capability_LlmGenerate].Models[config.ModelID] = modelConstraint + + if *cfg.Network != "offchain" { + n.SetBasePriceForCap("default", core.Capability_LlmGenerate, config.ModelID, autoPrice) + } case "segment-anything-2": _, ok := capabilityConstraints[core.Capability_SegmentAnything2] if !ok { diff --git a/core/ai.go b/core/ai.go index 31f331e49e..1d22536d81 100644 --- a/core/ai.go +++ b/core/ai.go @@ -22,6 +22,7 @@ type AI interface { ImageToVideo(context.Context, worker.GenImageToVideoMultipartRequestBody) (*worker.VideoResponse, error) Upscale(context.Context, worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) AudioToText(context.Context, worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) + LlmGenerate(context.Context, worker.GenLLMFormdataRequestBody) (interface{}, error) SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error Stop(context.Context) error diff --git a/core/capabilities.go b/core/capabilities.go index 1ac674a9ed..4d74a3e1f0 100644 --- a/core/capabilities.go +++ b/core/capabilities.go @@ -78,6 +78,7 @@ const ( Capability_ImageToVideo Capability_Upscale Capability_AudioToText + Capability_LlmGenerate Capability_SegmentAnything2 ) @@ -115,6 +116,7 @@ var CapabilityNameLookup = map[Capability]string{ Capability_ImageToVideo: "Image to video", Capability_Upscale: "Upscale", Capability_AudioToText: "Audio to text", + Capability_LlmGenerate: "LLM Generate", Capability_SegmentAnything2: "Segment anything 2", } diff --git a/core/orchestrator.go b/core/orchestrator.go index e55405cf8c..285e5c0850 100644 --- a/core/orchestrator.go +++ b/core/orchestrator.go @@ -130,6 +130,11 @@ func (orch *orchestrator) AudioToText(ctx context.Context, req worker.GenAudioTo return orch.node.AudioToText(ctx, req) } +// Return type is LlmResponse, but a stream is available as well as chan(string) +func (orch *orchestrator) LlmGenerate(ctx context.Context, req worker.GenLLMFormdataRequestBody) 
(interface{}, error) { + return orch.node.llmGenerate(ctx, req) +} + func (orch *orchestrator) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) { return orch.node.SegmentAnything2(ctx, req) } @@ -1062,6 +1067,10 @@ func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.GenImageToVi return &worker.ImageResponse{Images: videos}, nil } +func (n *LivepeerNode) llmGenerate(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) { + return n.AIWorker.LlmGenerate(ctx, req) +} + func (rtm *RemoteTranscoderManager) transcoderResults(tcID int64, res *RemoteTranscoderResult) { remoteChan, err := rtm.getTaskChan(tcID) if err != nil { diff --git a/go.mod b/go.mod index faa5c4eea2..e2cf00ff84 100644 --- a/go.mod +++ b/go.mod @@ -238,3 +238,5 @@ require ( lukechampine.com/blake3 v1.2.1 // indirect rsc.io/tmplfunc v0.0.3 // indirect ) + +replace github.com/livepeer/ai-worker => /Users/nico/livepool/ai-worker diff --git a/server/ai_http.go b/server/ai_http.go index 3f0bb97d9e..4e6498e39e 100644 --- a/server/ai_http.go +++ b/server/ai_http.go @@ -44,6 +44,7 @@ func startAIServer(lp lphttp) error { lp.transRPC.Handle("/image-to-video", oapiReqValidator(lp.ImageToVideo())) lp.transRPC.Handle("/upscale", oapiReqValidator(lp.Upscale())) lp.transRPC.Handle("/audio-to-text", oapiReqValidator(lp.AudioToText())) + lp.transRPC.Handle("/llm-generate", oapiReqValidator(lp.LlmGenerate())) lp.transRPC.Handle("/segment-anything-2", oapiReqValidator(lp.SegmentAnything2())) return nil @@ -181,6 +182,29 @@ func (h *lphttp) SegmentAnything2() http.Handler { }) } +func (h *lphttp) LlmGenerate() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + orch := h.orchestrator + + remoteAddr := getRemoteAddr(r) + ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr) + + multiRdr, err := r.MultipartReader() + if err != nil { + respondWithError(w, err.Error(), http.StatusBadRequest) + return + } + + var req worker.GenLLMFormdataRequestBody + if err := runtime.BindMultipart(&req, *multiRdr); err != nil { + respondWithError(w, err.Error(), http.StatusInternalServerError) + return + } + + handleAIRequest(ctx, w, r, orch, req) + }) +} + func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, orch Orchestrator, req interface{}) { payment, err := getPayment(r.Header.Get(paymentHeader)) if err != nil { @@ -305,6 +329,15 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request return } outPixels *= 1000 // Convert to milliseconds + case worker.GenLLMFormdataRequestBody: + pipeline = "llm-generate" + cap = core.Capability_LlmGenerate + modelID = *v.ModelId + submitFn = func(ctx context.Context) (interface{}, error) { + return orch.LlmGenerate(ctx, v) + } + + // TODO: handle tokens for pricing case worker.GenSegmentAnything2MultipartRequestBody: pipeline = "segment-anything-2" cap = core.Capability_SegmentAnything2 @@ -407,7 +440,37 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request monitor.AIJobProcessed(ctx, pipeline, modelID, monitor.AIJobInfo{LatencyScore: latencyScore, PricePerUnit: pricePerAIUnit}) } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(resp) + // Check if the response is a streaming response + if streamChan, ok := resp.(chan worker.LlmStreamChunk); ok { + // Set headers for SSE + w.Header().Set("Content-Type", 
"text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming unsupported!", http.StatusInternalServerError) + return + } + + for chunk := range streamChan { + data, err := json.Marshal(chunk) + if err != nil { + clog.Errorf(ctx, "Error marshaling stream chunk: %v", err) + continue + } + + fmt.Fprintf(w, "data: %s\n\n", data) + flusher.Flush() + + if chunk.Done { + break + } + } + } else { + // Non-streaming response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(resp) + } } diff --git a/server/ai_mediaserver.go b/server/ai_mediaserver.go index 8190348aa4..93c85288f8 100644 --- a/server/ai_mediaserver.go +++ b/server/ai_mediaserver.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "net/http" "time" @@ -69,6 +70,7 @@ func startAIMediaServer(ls *LivepeerServer) error { ls.HTTPMux.Handle("/image-to-video", oapiReqValidator(ls.ImageToVideo())) ls.HTTPMux.Handle("/image-to-video/result", ls.ImageToVideoResult()) ls.HTTPMux.Handle("/audio-to-text", oapiReqValidator(ls.AudioToText())) + ls.HTTPMux.Handle("/llm-generate", oapiReqValidator(ls.LlmGenerate())) ls.HTTPMux.Handle("/segment-anything-2", oapiReqValidator(ls.SegmentAnything2())) return nil @@ -394,6 +396,73 @@ func (ls *LivepeerServer) AudioToText() http.Handler { }) } +func (ls *LivepeerServer) LlmGenerate() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + remoteAddr := getRemoteAddr(r) + ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr) + requestID := string(core.RandomManifestID()) + ctx = clog.AddVal(ctx, "request_id", requestID) + + var req worker.LlmGenerateFormdataRequestBody + + multiRdr, err := r.MultipartReader() + if err != nil { + respondJsonError(ctx, w, err, http.StatusBadRequest) + return + } + + if err := runtime.BindMultipart(&req, *multiRdr); err != nil { + respondJsonError(ctx, w, err, http.StatusBadRequest) + return + } + + clog.V(common.VERBOSE).Infof(ctx, "Received LlmGenerate request prompt=%v model_id=%v", req.Prompt, *req.ModelId) + + params := aiRequestParams{ + node: ls.LivepeerNode, + os: drivers.NodeStorage.NewSession(requestID), + sessManager: ls.AISessionManager, + } + + start := time.Now() + resp, err := processLlmGenerate(ctx, params, req) + if err != nil { + var e *ServiceUnavailableError + if errors.As(err, &e) { + respondJsonError(ctx, w, err, http.StatusServiceUnavailable) + return + } + respondJsonError(ctx, w, err, http.StatusInternalServerError) + return + } + + took := time.Since(start) + clog.V(common.VERBOSE).Infof(ctx, "Processed LlmGenerate request prompt=%v model_id=%v took=%v", req.Prompt, *req.ModelId, took) + + if streamChan, ok := resp.(chan worker.LlmStreamChunk); ok { + // Handle streaming response (SSE) + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + + for chunk := range streamChan { + data, _ := json.Marshal(chunk) + fmt.Fprintf(w, "data: %s\n\n", data) + w.(http.Flusher).Flush() + if chunk.Done { + break + } + } + } else if llmResp, ok := resp.(*worker.LlmResponse); ok { + // Handle non-streaming response + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(llmResp) + } else { + http.Error(w, "Unexpected response type", http.StatusInternalServerError) + } + }) +} + func (ls *LivepeerServer) 
SegmentAnything2() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { remoteAddr := getRemoteAddr(r) diff --git a/server/ai_process.go b/server/ai_process.go index a244b82cc6..8a9a38bbc5 100644 --- a/server/ai_process.go +++ b/server/ai_process.go @@ -31,6 +31,7 @@ const defaultImageToImageModelID = "stabilityai/sdxl-turbo" const defaultImageToVideoModelID = "stabilityai/stable-video-diffusion-img2vid-xt" const defaultUpscaleModelID = "stabilityai/stable-diffusion-x4-upscaler" const defaultAudioToTextModelID = "openai/whisper-large-v3" +const defaultLlmGenerateModelID = "meta-llama/llama-3.1-8B-Instruct" const defaultSegmentAnything2ModelID = "facebook/sam2-hiera-large" type ServiceUnavailableError struct { @@ -812,6 +813,159 @@ func submitAudioToText(ctx context.Context, params aiRequestParams, sess *AISess return &res, nil } +func CalculateLlmGenerateLatencyScore(took time.Duration, tokensUsed int) float64 { + if tokensUsed <= 0 { + return 0 + } + + return took.Seconds() / float64(tokensUsed) +} + +func processLlmGenerate(ctx context.Context, params aiRequestParams, req worker.GenLLMFormdataRequestBody) (interface{}, error) { + resp, err := processAIRequest(ctx, params, req) + if err != nil { + return nil, err + } + + if req.Stream != nil && *req.Stream { + streamChan, ok := resp.(chan worker.LlmStreamChunk) + if !ok { + return nil, errors.New("unexpected response type for streaming request") + } + return streamChan, nil + } + + llmResp, ok := resp.(*worker.LLMResponse) + if !ok { + return nil, errors.New("unexpected response type") + } + + return llmResp, nil +} + +func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenLLMFormdataRequestBody) (interface{}, error) { + var buf bytes.Buffer + mw, err := worker.NewLLMMultipartWriter(&buf, req) + if err != nil { + if monitor.Enabled { + monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, nil) + } + return nil, err + } + + client, err := worker.NewClientWithResponses(sess.Transcoder(), worker.WithHTTPClient(httpClient)) + if err != nil { + if monitor.Enabled { + monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo) + } + return nil, err + } + + // TODO: calculate payment + setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, 0) + if err != nil { + if monitor.Enabled { + monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo) + } + return nil, err + } + defer completeBalanceUpdate(sess.BroadcastSession, balUpdate) + + start := time.Now() + resp, err := client.GenLLMWithBody(ctx, mw.FormDataContentType(), &buf, setHeaders) + if err != nil { + if monitor.Enabled { + monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo) + } + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body)) + } + + if req.Stream != nil && *req.Stream { + return handleSSEStream(ctx, resp.Body, sess, req, start) + } + + return handleNonStreamingResponse(ctx, resp.Body, sess, req, start) +} + +func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLLMFormdataRequestBody, start time.Time) (chan worker.LlmStreamChunk, error) { + streamChan := make(chan worker.LlmStreamChunk, 100) + go func() { + defer close(streamChan) + scanner := bufio.NewScanner(body) + var totalTokens int + for 
scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "data: ") { + data := strings.TrimPrefix(line, "data: ") + if data == "[DONE]" { + streamChan <- worker.LlmStreamChunk{Done: true, TokensUsed: totalTokens} + break + } + var chunk worker.LlmStreamChunk + if err := json.Unmarshal([]byte(data), &chunk); err != nil { + clog.Errorf(ctx, "Error unmarshaling SSE data: %v", err) + continue + } + totalTokens += chunk.TokensUsed + streamChan <- chunk + } + } + if err := scanner.Err(); err != nil { + clog.Errorf(ctx, "Error reading SSE stream: %v", err) + } + + took := time.Since(start) + sess.LatencyScore = CalculateLlmGenerateLatencyScore(took, totalTokens) + + if monitor.Enabled { + var pricePerAIUnit float64 + if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 { + pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit) + } + monitor.AIRequestFinished(ctx, "llm-generate", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo) + } + }() + + return streamChan, nil +} + +func handleNonStreamingResponse(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLLMFormdataRequestBody, start time.Time) (*worker.LLMResponse, error) { + data, err := io.ReadAll(body) + if err != nil { + if monitor.Enabled { + monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo) + } + return nil, err + } + + var res worker.LLMResponse + if err := json.Unmarshal(data, &res); err != nil { + if monitor.Enabled { + monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo) + } + return nil, err + } + + took := time.Since(start) + sess.LatencyScore = CalculateLlmGenerateLatencyScore(took, res.TokensUsed) + + if monitor.Enabled { + var pricePerAIUnit float64 + if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 { + pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit) + } + monitor.AIRequestFinished(ctx, "llm-generate", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo) + } + + return &res, nil +} + func processAIRequest(ctx context.Context, params aiRequestParams, req interface{}) (interface{}, error) { var cap core.Capability var modelID string @@ -863,6 +1017,15 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) { return submitAudioToText(ctx, params, sess, v) } + case worker.GenLLMFormdataRequestBody: + cap = core.Capability_LlmGenerate + modelID = defaultLlmGenerateModelID + if v.ModelId != nil { + modelID = *v.ModelId + } + submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) { + return submitLlmGenerate(ctx, params, sess, v) + } case worker.GenSegmentAnything2MultipartRequestBody: cap = core.Capability_SegmentAnything2 modelID = defaultSegmentAnything2ModelID diff --git a/server/ai_process_test.go b/server/ai_process_test.go new file mode 100644 index 0000000000..e584637ef2 --- /dev/null +++ b/server/ai_process_test.go @@ -0,0 +1,38 @@ +package server + +import ( + "context" + "reflect" + "testing" + + "github.com/livepeer/ai-worker/worker" +) + +func Test_submitLlmGenerate(t *testing.T) { + type args struct { + ctx context.Context + params aiRequestParams + 
sess *AISession + req worker.LlmGenerateFormdataRequestBody + } + tests := []struct { + name string + args args + want interface{} + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := submitLlmGenerate(tt.args.ctx, tt.args.params, tt.args.sess, tt.args.req) + if (err != nil) != tt.wantErr { + t.Errorf("submitLlmGenerate() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("submitLlmGenerate() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/server/rpc.go b/server/rpc.go index 17bd8727f8..67435ef8cc 100644 --- a/server/rpc.go +++ b/server/rpc.go @@ -68,6 +68,7 @@ type Orchestrator interface { ImageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) + LlmGenerate(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) } From 4a9a4bc328ac61e3baf020c01186c0306c3335ce Mon Sep 17 00:00:00 2001 From: kyriediculous Date: Thu, 1 Aug 2024 04:14:53 +0200 Subject: [PATCH 2/4] add basic pricing based on max out tokens --- server/ai_http.go | 8 +++++++- server/ai_process.go | 8 ++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/server/ai_http.go b/server/ai_http.go index 4e6498e39e..4a40ed3b0f 100644 --- a/server/ai_http.go +++ b/server/ai_http.go @@ -337,7 +337,13 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request return orch.LlmGenerate(ctx, v) } - // TODO: handle tokens for pricing + if v.MaxTokens == nil { + respondWithError(w, "MaxTokens not specified", http.StatusBadRequest) + return + } + + // TODO: Improve pricing + outPixels = int64(*v.MaxTokens) case worker.GenSegmentAnything2MultipartRequestBody: pipeline = "segment-anything-2" cap = core.Capability_SegmentAnything2 diff --git a/server/ai_process.go b/server/ai_process.go index 8a9a38bbc5..e757cf228a 100644 --- a/server/ai_process.go +++ b/server/ai_process.go @@ -861,8 +861,12 @@ func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISess return nil, err } - // TODO: calculate payment - setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, 0) + // TODO: Improve pricing + if req.MaxTokens == nil { + req.MaxTokens = new(int) + *req.MaxTokens = 256 + } + setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, int64(*req.MaxTokens)) if err != nil { if monitor.Enabled { monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo) From eca254f32d9a399d8f7f3008e0667f0d4dd797bc Mon Sep 17 00:00:00 2001 From: kyriediculous Date: Mon, 5 Aug 2024 20:37:19 +0200 Subject: [PATCH 3/4] misc: update ai-worker dependency and its usage --- cmd/livepeer/starter/starter.go | 12 +++++------ core/ai.go | 2 +- core/capabilities.go | 4 ++-- core/orchestrator.go | 11 ++++------ go.mod | 4 +--- go.sum | 4 ++-- server/ai_http.go | 12 +++++------ server/ai_mediaserver.go | 14 ++++++------- server/ai_process.go | 37 +++++++++++++++++---------------- server/ai_process_test.go | 10 ++++----- server/rpc.go | 2 +- server/rpc_test.go | 6 ++++++ 12 files changed, 60 insertions(+), 58 deletions(-) diff --git 
a/cmd/livepeer/starter/starter.go b/cmd/livepeer/starter/starter.go index 0ad4a614b1..2e9c1f3160 100755 --- a/cmd/livepeer/starter/starter.go +++ b/cmd/livepeer/starter/starter.go @@ -1320,19 +1320,19 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) { } n.SetBasePriceForCap("default", core.Capability_AudioToText, config.ModelID, autoPrice) - case "llm-generate": - _, ok := capabilityConstraints[core.Capability_LlmGenerate] + case "llm": + _, ok := capabilityConstraints[core.Capability_LLM] if !ok { - aiCaps = append(aiCaps, core.Capability_LlmGenerate) - capabilityConstraints[core.Capability_LlmGenerate] = &core.PerCapabilityConstraints{ + aiCaps = append(aiCaps, core.Capability_LLM) + capabilityConstraints[core.Capability_LLM] = &core.CapabilityConstraints{ Models: make(map[string]*core.ModelConstraint), } } - capabilityConstraints[core.Capability_LlmGenerate].Models[config.ModelID] = modelConstraint + capabilityConstraints[core.Capability_LLM].Models[config.ModelID] = modelConstraint if *cfg.Network != "offchain" { - n.SetBasePriceForCap("default", core.Capability_LlmGenerate, config.ModelID, autoPrice) + n.SetBasePriceForCap("default", core.Capability_LLM, config.ModelID, autoPrice) } case "segment-anything-2": _, ok := capabilityConstraints[core.Capability_SegmentAnything2] diff --git a/core/ai.go b/core/ai.go index 1d22536d81..26e38b3586 100644 --- a/core/ai.go +++ b/core/ai.go @@ -22,7 +22,7 @@ type AI interface { ImageToVideo(context.Context, worker.GenImageToVideoMultipartRequestBody) (*worker.VideoResponse, error) Upscale(context.Context, worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) AudioToText(context.Context, worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) - LlmGenerate(context.Context, worker.GenLLMFormdataRequestBody) (interface{}, error) + LLM(context.Context, worker.GenLLMFormdataRequestBody) (interface{}, error) SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error Stop(context.Context) error diff --git a/core/capabilities.go b/core/capabilities.go index 4d74a3e1f0..1280e1727d 100644 --- a/core/capabilities.go +++ b/core/capabilities.go @@ -78,7 +78,7 @@ const ( Capability_ImageToVideo Capability_Upscale Capability_AudioToText - Capability_LlmGenerate + Capability_LLM Capability_SegmentAnything2 ) @@ -116,7 +116,7 @@ var CapabilityNameLookup = map[Capability]string{ Capability_ImageToVideo: "Image to video", Capability_Upscale: "Upscale", Capability_AudioToText: "Audio to text", - Capability_LlmGenerate: "LLM Generate", + Capability_LLM: "Large Language Model", Capability_SegmentAnything2: "Segment anything 2", } diff --git a/core/orchestrator.go b/core/orchestrator.go index 285e5c0850..7ad5dc0b3d 100644 --- a/core/orchestrator.go +++ b/core/orchestrator.go @@ -130,9 +130,10 @@ func (orch *orchestrator) AudioToText(ctx context.Context, req worker.GenAudioTo return orch.node.AudioToText(ctx, req) } -// Return type is LlmResponse, but a stream is available as well as chan(string) -func (orch *orchestrator) LlmGenerate(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) { - return orch.node.llmGenerate(ctx, req) +// Return type is LLMResponse, but a stream is available as well as chan(string) +func (orch *orchestrator) LLM(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) { + return 
orch.node.AIWorker.LLM(ctx, req) + } func (orch *orchestrator) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) { @@ -1067,10 +1068,6 @@ func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.GenImageToVi return &worker.ImageResponse{Images: videos}, nil } -func (n *LivepeerNode) llmGenerate(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) { - return n.AIWorker.LlmGenerate(ctx, req) -} - func (rtm *RemoteTranscoderManager) transcoderResults(tcID int64, res *RemoteTranscoderResult) { remoteChan, err := rtm.getTaskChan(tcID) if err != nil { diff --git a/go.mod b/go.mod index e2cf00ff84..0abeaf4c2b 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/golang/protobuf v1.5.4 github.com/jaypipes/ghw v0.10.0 github.com/jaypipes/pcidb v1.0.0 - github.com/livepeer/ai-worker v0.6.0 + github.com/livepeer/ai-worker v0.7.0 github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b github.com/livepeer/livepeer-data v0.7.5-0.20231004073737-06f1f383fb18 github.com/livepeer/lpms v0.0.0-20240909171057-fe5aff1fa6a2 @@ -238,5 +238,3 @@ require ( lukechampine.com/blake3 v1.2.1 // indirect rsc.io/tmplfunc v0.0.3 // indirect ) - -replace github.com/livepeer/ai-worker => /Users/nico/livepool/ai-worker diff --git a/go.sum b/go.sum index dd4a295379..da5fdfb72b 100644 --- a/go.sum +++ b/go.sum @@ -623,8 +623,8 @@ github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4n github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI= github.com/libp2p/go-openssl v0.1.0 h1:LBkKEcUv6vtZIQLVTegAil8jbNpJErQ9AnT+bWV+Ooo= github.com/libp2p/go-openssl v0.1.0/go.mod h1:OiOxwPpL3n4xlenjx2h7AwSGaFSC/KZvf6gNdOBQMtc= -github.com/livepeer/ai-worker v0.6.0 h1:sGldUavfbTXPQDKc1a80/zgK8G1VdYRAxiuFTP0YyOU= -github.com/livepeer/ai-worker v0.6.0/go.mod h1:91lMzkzVuwR9kZ0EzXwf+7yVhLaNVmYAfmBtn7t3cQA= +github.com/livepeer/ai-worker v0.7.0 h1:9z5Uz9WvKyQTXiurWim1ewDcVPLzz7EYZEfm2qtLAaw= +github.com/livepeer/ai-worker v0.7.0/go.mod h1:91lMzkzVuwR9kZ0EzXwf+7yVhLaNVmYAfmBtn7t3cQA= github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b h1:VQcnrqtCA2UROp7q8ljkh2XA/u0KRgVv0S1xoUvOweE= github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b/go.mod h1:hwJ5DKhl+pTanFWl+EUpw1H7ukPO/H+MFpgA7jjshzw= github.com/livepeer/joy4 v0.1.2-0.20191121080656-b2fea45cbded h1:ZQlvR5RB4nfT+cOQee+WqmaDOgGtP2oDMhcVvR4L0yA= diff --git a/server/ai_http.go b/server/ai_http.go index 4a40ed3b0f..ac4e4a0122 100644 --- a/server/ai_http.go +++ b/server/ai_http.go @@ -44,7 +44,7 @@ func startAIServer(lp lphttp) error { lp.transRPC.Handle("/image-to-video", oapiReqValidator(lp.ImageToVideo())) lp.transRPC.Handle("/upscale", oapiReqValidator(lp.Upscale())) lp.transRPC.Handle("/audio-to-text", oapiReqValidator(lp.AudioToText())) - lp.transRPC.Handle("/llm-generate", oapiReqValidator(lp.LlmGenerate())) + lp.transRPC.Handle("/llm", oapiReqValidator(lp.LLM())) lp.transRPC.Handle("/segment-anything-2", oapiReqValidator(lp.SegmentAnything2())) return nil @@ -182,7 +182,7 @@ func (h *lphttp) SegmentAnything2() http.Handler { }) } -func (h *lphttp) LlmGenerate() http.Handler { +func (h *lphttp) LLM() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { orch := h.orchestrator @@ -330,11 +330,11 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request } outPixels *= 1000 // Convert to milliseconds case 
worker.GenLLMFormdataRequestBody: - pipeline = "llm-generate" - cap = core.Capability_LlmGenerate + pipeline = "llm" + cap = core.Capability_LLM modelID = *v.ModelId submitFn = func(ctx context.Context) (interface{}, error) { - return orch.LlmGenerate(ctx, v) + return orch.LLM(ctx, v) } if v.MaxTokens == nil { @@ -447,7 +447,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request } // Check if the response is a streaming response - if streamChan, ok := resp.(chan worker.LlmStreamChunk); ok { + if streamChan, ok := resp.(<-chan worker.LlmStreamChunk); ok { // Set headers for SSE w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") diff --git a/server/ai_mediaserver.go b/server/ai_mediaserver.go index 93c85288f8..d8bca64a5b 100644 --- a/server/ai_mediaserver.go +++ b/server/ai_mediaserver.go @@ -70,7 +70,7 @@ func startAIMediaServer(ls *LivepeerServer) error { ls.HTTPMux.Handle("/image-to-video", oapiReqValidator(ls.ImageToVideo())) ls.HTTPMux.Handle("/image-to-video/result", ls.ImageToVideoResult()) ls.HTTPMux.Handle("/audio-to-text", oapiReqValidator(ls.AudioToText())) - ls.HTTPMux.Handle("/llm-generate", oapiReqValidator(ls.LlmGenerate())) + ls.HTTPMux.Handle("/llm", oapiReqValidator(ls.LLM())) ls.HTTPMux.Handle("/segment-anything-2", oapiReqValidator(ls.SegmentAnything2())) return nil @@ -396,14 +396,14 @@ func (ls *LivepeerServer) AudioToText() http.Handler { }) } -func (ls *LivepeerServer) LlmGenerate() http.Handler { +func (ls *LivepeerServer) LLM() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { remoteAddr := getRemoteAddr(r) ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr) requestID := string(core.RandomManifestID()) ctx = clog.AddVal(ctx, "request_id", requestID) - var req worker.LlmGenerateFormdataRequestBody + var req worker.GenLLMFormdataRequestBody multiRdr, err := r.MultipartReader() if err != nil { @@ -416,7 +416,7 @@ func (ls *LivepeerServer) LlmGenerate() http.Handler { return } - clog.V(common.VERBOSE).Infof(ctx, "Received LlmGenerate request prompt=%v model_id=%v", req.Prompt, *req.ModelId) + clog.V(common.VERBOSE).Infof(ctx, "Received LLM request prompt=%v model_id=%v stream=%v", req.Prompt, *req.ModelId, *req.Stream) params := aiRequestParams{ node: ls.LivepeerNode, @@ -425,7 +425,7 @@ func (ls *LivepeerServer) LlmGenerate() http.Handler { } start := time.Now() - resp, err := processLlmGenerate(ctx, params, req) + resp, err := processLLM(ctx, params, req) if err != nil { var e *ServiceUnavailableError if errors.As(err, &e) { @@ -437,7 +437,7 @@ func (ls *LivepeerServer) LlmGenerate() http.Handler { } took := time.Since(start) - clog.V(common.VERBOSE).Infof(ctx, "Processed LlmGenerate request prompt=%v model_id=%v took=%v", req.Prompt, *req.ModelId, took) + clog.V(common.VERBOSE).Infof(ctx, "Processed LLM request prompt=%v model_id=%v took=%v", req.Prompt, *req.ModelId, took) if streamChan, ok := resp.(chan worker.LlmStreamChunk); ok { // Handle streaming response (SSE) @@ -453,7 +453,7 @@ func (ls *LivepeerServer) LlmGenerate() http.Handler { break } } - } else if llmResp, ok := resp.(*worker.LlmResponse); ok { + } else if llmResp, ok := resp.(*worker.LLMResponse); ok { // Handle non-streaming response w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(llmResp) diff --git a/server/ai_process.go b/server/ai_process.go index e757cf228a..f39e321a73 100644 --- a/server/ai_process.go +++ b/server/ai_process.go @@ -31,7 +31,7 @@ 
const defaultImageToImageModelID = "stabilityai/sdxl-turbo" const defaultImageToVideoModelID = "stabilityai/stable-video-diffusion-img2vid-xt" const defaultUpscaleModelID = "stabilityai/stable-diffusion-x4-upscaler" const defaultAudioToTextModelID = "openai/whisper-large-v3" -const defaultLlmGenerateModelID = "meta-llama/llama-3.1-8B-Instruct" +const defaultLLMModelID = "meta-llama/llama-3.1-8B-Instruct" const defaultSegmentAnything2ModelID = "facebook/sam2-hiera-large" type ServiceUnavailableError struct { @@ -813,7 +813,7 @@ func submitAudioToText(ctx context.Context, params aiRequestParams, sess *AISess return &res, nil } -func CalculateLlmGenerateLatencyScore(took time.Duration, tokensUsed int) float64 { +func CalculateLLMLatencyScore(took time.Duration, tokensUsed int) float64 { if tokensUsed <= 0 { return 0 } @@ -821,7 +821,7 @@ func CalculateLlmGenerateLatencyScore(took time.Duration, tokensUsed int) float6 return took.Seconds() / float64(tokensUsed) } -func processLlmGenerate(ctx context.Context, params aiRequestParams, req worker.GenLLMFormdataRequestBody) (interface{}, error) { +func processLLM(ctx context.Context, params aiRequestParams, req worker.GenLLMFormdataRequestBody) (interface{}, error) { resp, err := processAIRequest(ctx, params, req) if err != nil { return nil, err @@ -843,12 +843,12 @@ func processLlmGenerate(ctx context.Context, params aiRequestParams, req worker. return llmResp, nil } -func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenLLMFormdataRequestBody) (interface{}, error) { +func submitLLM(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenLLMFormdataRequestBody) (interface{}, error) { var buf bytes.Buffer mw, err := worker.NewLLMMultipartWriter(&buf, req) if err != nil { if monitor.Enabled { - monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, nil) + monitor.AIRequestError(err.Error(), "llm", *req.ModelId, nil) } return nil, err } @@ -856,7 +856,7 @@ func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISess client, err := worker.NewClientWithResponses(sess.Transcoder(), worker.WithHTTPClient(httpClient)) if err != nil { if monitor.Enabled { - monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo) + monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo) } return nil, err } @@ -869,7 +869,7 @@ func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISess setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, int64(*req.MaxTokens)) if err != nil { if monitor.Enabled { - monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo) + monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo) } return nil, err } @@ -879,11 +879,10 @@ func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISess resp, err := client.GenLLMWithBody(ctx, mw.FormDataContentType(), &buf, setHeaders) if err != nil { if monitor.Enabled { - monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo) + monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo) } return nil, err } - defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) @@ -901,6 +900,7 @@ func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, r streamChan := make(chan worker.LlmStreamChunk, 100) go func() { defer close(streamChan) + defer body.Close() 
scanner := bufio.NewScanner(body) var totalTokens int for scanner.Scan() { @@ -925,14 +925,14 @@ func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, r } took := time.Since(start) - sess.LatencyScore = CalculateLlmGenerateLatencyScore(took, totalTokens) + sess.LatencyScore = CalculateLLMLatencyScore(took, totalTokens) if monitor.Enabled { var pricePerAIUnit float64 if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 { pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit) } - monitor.AIRequestFinished(ctx, "llm-generate", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo) + monitor.AIRequestFinished(ctx, "llm", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo) } }() @@ -941,9 +941,10 @@ func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, r func handleNonStreamingResponse(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLLMFormdataRequestBody, start time.Time) (*worker.LLMResponse, error) { data, err := io.ReadAll(body) + defer body.Close() if err != nil { if monitor.Enabled { - monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo) + monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo) } return nil, err } @@ -951,20 +952,20 @@ func handleNonStreamingResponse(ctx context.Context, body io.ReadCloser, sess *A var res worker.LLMResponse if err := json.Unmarshal(data, &res); err != nil { if monitor.Enabled { - monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo) + monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo) } return nil, err } took := time.Since(start) - sess.LatencyScore = CalculateLlmGenerateLatencyScore(took, res.TokensUsed) + sess.LatencyScore = CalculateLLMLatencyScore(took, res.TokensUsed) if monitor.Enabled { var pricePerAIUnit float64 if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 { pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit) } - monitor.AIRequestFinished(ctx, "llm-generate", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo) + monitor.AIRequestFinished(ctx, "llm", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo) } return &res, nil @@ -1022,13 +1023,13 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface return submitAudioToText(ctx, params, sess, v) } case worker.GenLLMFormdataRequestBody: - cap = core.Capability_LlmGenerate - modelID = defaultLlmGenerateModelID + cap = core.Capability_LLM + modelID = defaultLLMModelID if v.ModelId != nil { modelID = *v.ModelId } submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) { - return submitLlmGenerate(ctx, params, sess, v) + return submitLLM(ctx, params, sess, v) } case worker.GenSegmentAnything2MultipartRequestBody: cap = core.Capability_SegmentAnything2 diff --git a/server/ai_process_test.go b/server/ai_process_test.go index e584637ef2..c64771c931 100644 --- a/server/ai_process_test.go +++ b/server/ai_process_test.go @@ -8,12 +8,12 @@ import ( "github.com/livepeer/ai-worker/worker" ) -func Test_submitLlmGenerate(t 
*testing.T) { +func Test_submitLLM(t *testing.T) { type args struct { ctx context.Context params aiRequestParams sess *AISession - req worker.LlmGenerateFormdataRequestBody + req worker.GenLLMFormdataRequestBody } tests := []struct { name string @@ -25,13 +25,13 @@ func Test_submitLlmGenerate(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := submitLlmGenerate(tt.args.ctx, tt.args.params, tt.args.sess, tt.args.req) + got, err := submitLLM(tt.args.ctx, tt.args.params, tt.args.sess, tt.args.req) if (err != nil) != tt.wantErr { - t.Errorf("submitLlmGenerate() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("submitLLM() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("submitLlmGenerate() = %v, want %v", got, tt.want) + t.Errorf("submitLLM() = %v, want %v", got, tt.want) } }) } diff --git a/server/rpc.go b/server/rpc.go index 67435ef8cc..9ff11e35a3 100644 --- a/server/rpc.go +++ b/server/rpc.go @@ -68,7 +68,7 @@ type Orchestrator interface { ImageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) - LlmGenerate(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) + LLM(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) } diff --git a/server/rpc_test.go b/server/rpc_test.go index 6f710a319d..df2b4204bd 100644 --- a/server/rpc_test.go +++ b/server/rpc_test.go @@ -202,6 +202,9 @@ func (r *stubOrchestrator) Upscale(ctx context.Context, req worker.GenUpscaleMul func (r *stubOrchestrator) AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) { return nil, nil } +func (r *stubOrchestrator) LLM(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) { + return nil, nil +} func (r *stubOrchestrator) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) { return nil, nil } @@ -1388,6 +1391,9 @@ func (r *mockOrchestrator) Upscale(ctx context.Context, req worker.GenUpscaleMul func (r *mockOrchestrator) AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) { return nil, nil } +func (r *mockOrchestrator) LLM(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) { + return nil, nil +} func (r *mockOrchestrator) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) { return nil, nil } From 0be0966f1d1d5c78b2b366af30fd20cdde15f77b Mon Sep 17 00:00:00 2001 From: Nico Vergauwen Date: Tue, 1 Oct 2024 21:41:27 +0200 Subject: [PATCH 4/4] change capability description Co-authored-by: Rick Staa --- core/capabilities.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/capabilities.go b/core/capabilities.go index 1280e1727d..b723b43c28 100644 --- a/core/capabilities.go +++ b/core/capabilities.go @@ -116,7 +116,7 @@ var CapabilityNameLookup = map[Capability]string{ Capability_ImageToVideo: "Image to video", Capability_Upscale: "Upscale", Capability_AudioToText: "Audio to 
text", - Capability_LLM: "Large Language Model", + Capability_LLM: "Large language model", Capability_SegmentAnything2: "Segment anything 2", }