From 2095440e424dd5f8668f88a54b091f6105786d44 Mon Sep 17 00:00:00 2001
From: Brad P
Date: Fri, 20 Sep 2024 22:07:04 -0500
Subject: [PATCH] update for new codegen and update to segfault fix

---
 core/ai.go               |  2 +-
 core/orchestrator.go     |  4 ++--
 server/ai_http.go        |  4 ++--
 server/ai_mediaserver.go |  5 +++--
 server/ai_process.go     | 12 ++++++------
 server/rpc.go            |  2 +-
 6 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/core/ai.go b/core/ai.go
index ebc7e3dfc..0b7d22341 100644
--- a/core/ai.go
+++ b/core/ai.go
@@ -23,7 +23,7 @@ type AI interface {
 	Upscale(context.Context, worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
 	AudioToText(context.Context, worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
 	SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
-	LlmGenerate(context.Context, worker.LlmGenerateLlmGeneratePostFormdataRequestBody) (interface{}, error)
+	LlmGenerate(context.Context, worker.GenLlmFormdataRequestBody) (interface{}, error)
 	Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
 	Stop(context.Context) error
 	HasCapacity(pipeline, modelID string) bool
diff --git a/core/orchestrator.go b/core/orchestrator.go
index e3685fc7f..aff530396 100644
--- a/core/orchestrator.go
+++ b/core/orchestrator.go
@@ -135,7 +135,7 @@ func (orch *orchestrator) SegmentAnything2(ctx context.Context, req worker.GenSe
 }
 
 // Return type is LlmResponse, but a stream is available as well as chan(string)
-func (orch *orchestrator) LlmGenerate(ctx context.Context, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody) (interface{}, error) {
+func (orch *orchestrator) LlmGenerate(ctx context.Context, req worker.GenLlmFormdataRequestBody) (interface{}, error) {
 	return orch.node.llmGenerate(ctx, req)
 }
 
@@ -1056,7 +1056,7 @@ func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.GenImageToVi
 	return &worker.ImageResponse{Images: videos}, nil
 }
 
-func (n *LivepeerNode) llmGenerate(ctx context.Context, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody) (interface{}, error) {
+func (n *LivepeerNode) llmGenerate(ctx context.Context, req worker.GenLlmFormdataRequestBody) (interface{}, error) {
 	return n.AIWorker.LlmGenerate(ctx, req)
 }
 
diff --git a/server/ai_http.go b/server/ai_http.go
index 7238e38cd..e814b9d92 100644
--- a/server/ai_http.go
+++ b/server/ai_http.go
@@ -195,7 +195,7 @@ func (h *lphttp) LlmGenerate() http.Handler {
 			return
 		}
-		var req worker.LlmGenerateLlmGeneratePostFormdataRequestBody
+		var req worker.GenLlmFormdataRequestBody
 		if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
 			respondWithError(w, err.Error(), http.StatusInternalServerError)
 			return
 		}
@@ -348,7 +348,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
 			return
 		}
 		outPixels = int64(config.Height) * int64(config.Width)
-	case worker.LlmGenerateLlmGeneratePostFormdataRequestBody:
+	case worker.GenLlmFormdataRequestBody:
 		pipeline = "llm-generate"
 		cap = core.Capability_LlmGenerate
 		modelID = *v.ModelId
diff --git a/server/ai_mediaserver.go b/server/ai_mediaserver.go
index 88cb233ba..e4c7d6176 100644
--- a/server/ai_mediaserver.go
+++ b/server/ai_mediaserver.go
@@ -436,7 +436,7 @@ func (ls *LivepeerServer) LlmGenerate() http.Handler {
 		requestID := string(core.RandomManifestID())
 		ctx = clog.AddVal(ctx, "request_id", requestID)
 
-		var req worker.LlmGenerateLlmGeneratePostFormdataRequestBody
+		var req worker.GenLlmFormdataRequestBody
 
 		multiRdr, err := r.MultipartReader()
 		if err != nil {
@@ -450,9 +450,10 @@ func (ls *LivepeerServer) LlmGenerate() http.Handler {
 		}
 
 		streamResponse := false
-		if *req.Stream {
+		if req.Stream != nil {
 			streamResponse = *req.Stream
 		}
+
 		clog.V(common.VERBOSE).Infof(ctx, "Received LlmGenerate request prompt=%v model_id=%v stream=%v", req.Prompt, *req.ModelId, streamResponse)
 
 		params := aiRequestParams{
diff --git a/server/ai_process.go b/server/ai_process.go
index 8404ed401..af4704069 100644
--- a/server/ai_process.go
+++ b/server/ai_process.go
@@ -801,7 +801,7 @@ func CalculateLlmGenerateLatencyScore(took time.Duration, tokensUsed int) float6
 	return took.Seconds() / float64(tokensUsed)
 }
 
-func processLlmGenerate(ctx context.Context, params aiRequestParams, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody) (interface{}, error) {
+func processLlmGenerate(ctx context.Context, params aiRequestParams, req worker.GenLlmFormdataRequestBody) (interface{}, error) {
 	resp, err := processAIRequest(ctx, params, req)
 	if err != nil {
 		return nil, err
@@ -823,7 +823,7 @@ func processLlmGenerate(ctx context.Context, params aiRequestParams, req worker.
 	return llmResp, nil
 }
 
-func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISession, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody) (interface{}, error) {
+func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenLlmFormdataRequestBody) (interface{}, error) {
 	var buf bytes.Buffer
 	mw, err := worker.NewLlmGenerateMultipartWriter(&buf, req)
 	if err != nil {
@@ -856,7 +856,7 @@ func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISess
 	defer completeBalanceUpdate(sess.BroadcastSession, balUpdate)
 
 	start := time.Now()
-	resp, err := client.LlmGenerateLlmGeneratePostWithBody(ctx, mw.FormDataContentType(), &buf, setHeaders)
+	resp, err := client.GenLlmWithBody(ctx, mw.FormDataContentType(), &buf, setHeaders)
 	if err != nil {
 		if monitor.Enabled {
 			monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
@@ -876,7 +876,7 @@ func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISess
 	return handleNonStreamingResponse(ctx, resp.Body, sess, req, start)
 }
 
-func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody, start time.Time) (chan worker.LlmStreamChunk, error) {
+func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLlmFormdataRequestBody, start time.Time) (chan worker.LlmStreamChunk, error) {
 	streamChan := make(chan worker.LlmStreamChunk, 100)
 	go func() {
 		defer close(streamChan)
@@ -919,7 +919,7 @@ func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, r
 	return streamChan, nil
 }
 
-func handleNonStreamingResponse(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody, start time.Time) (*worker.LlmResponse, error) {
+func handleNonStreamingResponse(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLlmFormdataRequestBody, start time.Time) (*worker.LlmResponse, error) {
 	data, err := io.ReadAll(body)
 	defer body.Close()
 	if err != nil {
@@ -1011,7 +1011,7 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
 		submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
 			return submitSegmentAnything2(ctx, params, sess, v)
 		}
-	case worker.LlmGenerateLlmGeneratePostFormdataRequestBody:
+	case worker.GenLlmFormdataRequestBody:
 		cap = core.Capability_LlmGenerate
 		modelID = defaultLlmGenerateModelID
 		if v.ModelId != nil {
diff --git a/server/rpc.go b/server/rpc.go
index 8eccffc32..473256770 100644
--- a/server/rpc.go
+++ b/server/rpc.go
@@ -69,7 +69,7 @@ type Orchestrator interface {
 	Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
 	AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
 	SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
-	LlmGenerate(ctx context.Context, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody) (interface{}, error)
+	LlmGenerate(ctx context.Context, req worker.GenLlmFormdataRequestBody) (interface{}, error)
 }
 
 // Balance describes methods for a session's balance maintenance
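
Note on the segfault half of this change: Stream is an optional form field, so
the generated Go type exposes it as a *bool, and the old "if *req.Stream"
dereferenced a nil pointer whenever a client omitted the field. Below is a
minimal standalone sketch of the failing and fixed patterns; the Request
struct here is a stand-in for the generated worker.GenLlmFormdataRequestBody,
not the real type.

package main

import "fmt"

// Request stands in for the generated formdata request type; optional
// form fields come out of codegen as pointers, so an omitted "stream"
// field arrives as a nil *bool.
type Request struct {
	Prompt string
	Stream *bool
}

func main() {
	req := Request{Prompt: "hello"} // client omitted "stream"

	// Before the fix: dereferencing without a guard panics with a nil
	// pointer dereference (the reported segfault):
	//   if *req.Stream { ... }

	// After the fix: guard the dereference and default to false.
	streamResponse := false
	if req.Stream != nil {
		streamResponse = *req.Stream
	}
	fmt.Println("stream:", streamResponse) // prints "stream: false"
}

The same guard-then-default pattern applies to any optional pointer field the
codegen emits; the clog line in the patch still dereferences *req.ModelId
unguarded, which assumes model_id is always set.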
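
A second sketch, on the API shape these signatures keep: LlmGenerate returns
interface{} because the non-streaming path yields a *worker.LlmResponse while
the streaming path yields a chan worker.LlmStreamChunk (see handleSSEStream
and handleNonStreamingResponse above). This shows how a caller branches on
that result; the local types are stubs standing in for the worker package so
the sketch compiles on its own.

package main

import "fmt"

// Stubs for worker.LlmResponse and worker.LlmStreamChunk.
type LlmResponse struct{ Response string }

type LlmStreamChunk struct {
	Chunk string
	Done  bool
}

// consume mirrors what a caller of processLlmGenerate has to do with the
// interface{} result: type-switch on the two possible shapes.
func consume(result interface{}) {
	switch v := result.(type) {
	case *LlmResponse:
		fmt.Println("full response:", v.Response)
	case chan LlmStreamChunk:
		for chunk := range v { // drain the SSE-fed channel
			fmt.Print(chunk.Chunk)
		}
		fmt.Println()
	default:
		fmt.Println("unexpected result type")
	}
}

func main() {
	// Non-streaming path.
	consume(&LlmResponse{Response: "hello"})

	// Streaming path: a pre-filled, closed channel stands in for the
	// channel handleSSEStream feeds from the orchestrator's SSE body.
	ch := make(chan LlmStreamChunk, 2)
	ch <- LlmStreamChunk{Chunk: "hel"}
	ch <- LlmStreamChunk{Chunk: "lo", Done: true}
	close(ch)
	consume(ch)
}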