LLM pipeline with stream support #3114

Merged 4 commits on Oct 1, 2024

16 changes: 16 additions & 0 deletions cmd/livepeer/starter/starter.go
@@ -1318,6 +1318,22 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
	if *cfg.Network != "offchain" {
		n.SetBasePriceForCap("default", core.Capability_AudioToText, config.ModelID, autoPrice)
	}

case "llm":
_, ok := capabilityConstraints[core.Capability_LLM]
if !ok {
aiCaps = append(aiCaps, core.Capability_LLM)
capabilityConstraints[core.Capability_LLM] = &core.CapabilityConstraints{
Models: make(map[string]*core.ModelConstraint),
}
}

capabilityConstraints[core.Capability_LLM].Models[config.ModelID] = modelConstraint

if *cfg.Network != "offchain" {
n.SetBasePriceForCap("default", core.Capability_LLM, config.ModelID, autoPrice)
}
case "segment-anything-2":
_, ok := capabilityConstraints[core.Capability_SegmentAnything2]
if !ok {
1 change: 1 addition & 0 deletions core/ai.go
@@ -22,6 +22,7 @@ type AI interface {
	ImageToVideo(context.Context, worker.GenImageToVideoMultipartRequestBody) (*worker.VideoResponse, error)
	Upscale(context.Context, worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
	AudioToText(context.Context, worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
	LLM(context.Context, worker.GenLLMFormdataRequestBody) (interface{}, error)
	SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
	Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
	Stop(context.Context) error
2 changes: 2 additions & 0 deletions core/capabilities.go
@@ -78,6 +78,7 @@ const (
	Capability_ImageToVideo
	Capability_Upscale
	Capability_AudioToText
	Capability_LLM
	Capability_SegmentAnything2
)

@@ -115,6 +116,7 @@ var CapabilityNameLookup = map[Capability]string{
	Capability_ImageToVideo:     "Image to video",
	Capability_Upscale:          "Upscale",
	Capability_AudioToText:      "Audio to text",
	Capability_LLM:              "Large language model",
	Capability_SegmentAnything2: "Segment anything 2",
}

6 changes: 6 additions & 0 deletions core/orchestrator.go
@@ -130,6 +130,12 @@ func (orch *orchestrator) AudioToText(ctx context.Context, req worker.GenAudioTo
	return orch.node.AudioToText(ctx, req)
}

// LLM returns a *worker.LLMResponse for non-streaming requests; when streaming
// is requested, the result is a channel of worker.LlmStreamChunk instead.
func (orch *orchestrator) LLM(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
	return orch.node.AIWorker.LLM(ctx, req)
}

func (orch *orchestrator) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) {
	return orch.node.SegmentAnything2(ctx, req)
}
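Because orch.LLM returns interface{}, callers must disambiguate streaming from non-streaming results by type, as the gateway handler added in server/ai_mediaserver.go below does. A minimal sketch of that pattern, assuming the ai-worker types; the helper name and the Chunk field are illustrative, not part of this diff:

package main

import (
	"fmt"

	"github.com/livepeer/ai-worker/worker"
)

// handleLLMResult is a hypothetical helper showing the type switch callers use.
func handleLLMResult(resp interface{}) error {
	switch v := resp.(type) {
	case *worker.LLMResponse:
		// Non-streaming: the whole generation arrives at once.
		fmt.Printf("%+v\n", v)
		return nil
	case chan worker.LlmStreamChunk:
		// Streaming: drain chunks until one is marked Done.
		for chunk := range v {
			fmt.Print(chunk.Chunk) // Chunk field assumed from ai-worker's types
			if chunk.Done {
				break
			}
		}
		return nil
	default:
		return fmt.Errorf("unexpected LLM response type %T", v)
	}
}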
2 changes: 1 addition & 1 deletion go.mod
@@ -13,7 +13,7 @@ require (
	github.com/golang/protobuf v1.5.4
	github.com/jaypipes/ghw v0.10.0
	github.com/jaypipes/pcidb v1.0.0
-	github.com/livepeer/ai-worker v0.6.0
+	github.com/livepeer/ai-worker v0.7.0
	github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b
	github.com/livepeer/livepeer-data v0.7.5-0.20231004073737-06f1f383fb18
	github.com/livepeer/lpms v0.0.0-20240909171057-fe5aff1fa6a2
4 changes: 2 additions & 2 deletions go.sum
@@ -623,8 +623,8 @@ github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4n
github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI=
github.com/libp2p/go-openssl v0.1.0 h1:LBkKEcUv6vtZIQLVTegAil8jbNpJErQ9AnT+bWV+Ooo=
github.com/libp2p/go-openssl v0.1.0/go.mod h1:OiOxwPpL3n4xlenjx2h7AwSGaFSC/KZvf6gNdOBQMtc=
-github.com/livepeer/ai-worker v0.6.0 h1:sGldUavfbTXPQDKc1a80/zgK8G1VdYRAxiuFTP0YyOU=
-github.com/livepeer/ai-worker v0.6.0/go.mod h1:91lMzkzVuwR9kZ0EzXwf+7yVhLaNVmYAfmBtn7t3cQA=
+github.com/livepeer/ai-worker v0.7.0 h1:9z5Uz9WvKyQTXiurWim1ewDcVPLzz7EYZEfm2qtLAaw=
+github.com/livepeer/ai-worker v0.7.0/go.mod h1:91lMzkzVuwR9kZ0EzXwf+7yVhLaNVmYAfmBtn7t3cQA=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b h1:VQcnrqtCA2UROp7q8ljkh2XA/u0KRgVv0S1xoUvOweE=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b/go.mod h1:hwJ5DKhl+pTanFWl+EUpw1H7ukPO/H+MFpgA7jjshzw=
github.com/livepeer/joy4 v0.1.2-0.20191121080656-b2fea45cbded h1:ZQlvR5RB4nfT+cOQee+WqmaDOgGtP2oDMhcVvR4L0yA=
75 changes: 72 additions & 3 deletions server/ai_http.go
@@ -44,6 +44,7 @@ func startAIServer(lp lphttp) error {
lp.transRPC.Handle("/image-to-video", oapiReqValidator(lp.ImageToVideo()))
lp.transRPC.Handle("/upscale", oapiReqValidator(lp.Upscale()))
lp.transRPC.Handle("/audio-to-text", oapiReqValidator(lp.AudioToText()))
lp.transRPC.Handle("/llm", oapiReqValidator(lp.LLM()))
lp.transRPC.Handle("/segment-anything-2", oapiReqValidator(lp.SegmentAnything2()))

return nil
@@ -181,6 +182,29 @@ func (h *lphttp) SegmentAnything2() http.Handler {
	})
}

func (h *lphttp) LLM() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		orch := h.orchestrator

		remoteAddr := getRemoteAddr(r)
		ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)

		multiRdr, err := r.MultipartReader()
		if err != nil {
			respondWithError(w, err.Error(), http.StatusBadRequest)
			return
		}

		var req worker.GenLLMFormdataRequestBody
		if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
			respondWithError(w, err.Error(), http.StatusInternalServerError)
			return
		}

		handleAIRequest(ctx, w, r, orch, req)
	})
}

func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, orch Orchestrator, req interface{}) {
	payment, err := getPayment(r.Header.Get(paymentHeader))
	if err != nil {
@@ -305,6 +329,21 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
			return
		}
		outPixels *= 1000 // Convert to milliseconds
	case worker.GenLLMFormdataRequestBody:
		pipeline = "llm"
		cap = core.Capability_LLM
		modelID = *v.ModelId
		submitFn = func(ctx context.Context) (interface{}, error) {
			return orch.LLM(ctx, v)
		}

		if v.MaxTokens == nil {
			respondWithError(w, "MaxTokens not specified", http.StatusBadRequest)
			return
		}

		// TODO: Improve pricing
		outPixels = int64(*v.MaxTokens)
	case worker.GenSegmentAnything2MultipartRequestBody:
		pipeline = "segment-anything-2"
		cap = core.Capability_SegmentAnything2
@@ -407,7 +446,37 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
		monitor.AIJobProcessed(ctx, pipeline, modelID, monitor.AIJobInfo{LatencyScore: latencyScore, PricePerUnit: pricePerAIUnit})
	}

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
// Check if the response is a streaming response
if streamChan, ok := resp.(<-chan worker.LlmStreamChunk); ok {
// Set headers for SSE
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")

flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
return
}

for chunk := range streamChan {
data, err := json.Marshal(chunk)
if err != nil {
clog.Errorf(ctx, "Error marshaling stream chunk: %v", err)
continue
}

fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()

if chunk.Done {
break
}
}
} else {
// Non-streaming response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
}
}
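A minimal client-side sketch for consuming the SSE stream written above. The wire format ("data: <json>\n\n", final chunk flagged Done) comes from this handler; the JSON field names ("chunk", "done") are assumptions about the ai-worker encoding, and the helper is illustrative, not part of this PR:

package llmclient

import (
	"bufio"
	"encoding/json"
	"io"
	"strings"
)

// llmStreamChunk is a local stand-in for worker.LlmStreamChunk; json tags assumed.
type llmStreamChunk struct {
	Chunk string `json:"chunk"`
	Done  bool   `json:"done"`
}

// readLLMStream accumulates "data: ..." SSE events until a chunk reports Done.
func readLLMStream(body io.Reader) (string, error) {
	var out strings.Builder
	scanner := bufio.NewScanner(body)
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue // skip the blank separator lines between events
		}
		var chunk llmStreamChunk
		if err := json.Unmarshal([]byte(strings.TrimPrefix(line, "data: ")), &chunk); err != nil {
			return out.String(), err
		}
		out.WriteString(chunk.Chunk)
		if chunk.Done {
			break
		}
	}
	return out.String(), scanner.Err()
}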
69 changes: 69 additions & 0 deletions server/ai_mediaserver.go
@@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"time"

@@ -69,6 +70,7 @@ func startAIMediaServer(ls *LivepeerServer) error {
ls.HTTPMux.Handle("/image-to-video", oapiReqValidator(ls.ImageToVideo()))
ls.HTTPMux.Handle("/image-to-video/result", ls.ImageToVideoResult())
ls.HTTPMux.Handle("/audio-to-text", oapiReqValidator(ls.AudioToText()))
ls.HTTPMux.Handle("/llm", oapiReqValidator(ls.LLM()))
ls.HTTPMux.Handle("/segment-anything-2", oapiReqValidator(ls.SegmentAnything2()))

return nil
@@ -394,6 +396,73 @@ func (ls *LivepeerServer) AudioToText() http.Handler {
	})
}

func (ls *LivepeerServer) LLM() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		remoteAddr := getRemoteAddr(r)
		ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)
		requestID := string(core.RandomManifestID())
		ctx = clog.AddVal(ctx, "request_id", requestID)

		var req worker.GenLLMFormdataRequestBody

		multiRdr, err := r.MultipartReader()
		if err != nil {
			respondJsonError(ctx, w, err, http.StatusBadRequest)
			return
		}

		if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
			respondJsonError(ctx, w, err, http.StatusBadRequest)
			return
		}

		clog.V(common.VERBOSE).Infof(ctx, "Received LLM request prompt=%v model_id=%v stream=%v", req.Prompt, *req.ModelId, *req.Stream)

		params := aiRequestParams{
			node:        ls.LivepeerNode,
			os:          drivers.NodeStorage.NewSession(requestID),
			sessManager: ls.AISessionManager,
		}

		start := time.Now()
		resp, err := processLLM(ctx, params, req)
		if err != nil {
			var e *ServiceUnavailableError
			if errors.As(err, &e) {
				respondJsonError(ctx, w, err, http.StatusServiceUnavailable)
				return
			}
			respondJsonError(ctx, w, err, http.StatusInternalServerError)
			return
		}

		took := time.Since(start)
		clog.V(common.VERBOSE).Infof(ctx, "Processed LLM request prompt=%v model_id=%v took=%v", req.Prompt, *req.ModelId, took)

		if streamChan, ok := resp.(chan worker.LlmStreamChunk); ok {
			// Handle streaming response (SSE)
			w.Header().Set("Content-Type", "text/event-stream")
			w.Header().Set("Cache-Control", "no-cache")
			w.Header().Set("Connection", "keep-alive")

			for chunk := range streamChan {
				data, _ := json.Marshal(chunk)
				fmt.Fprintf(w, "data: %s\n\n", data)
				w.(http.Flusher).Flush()
				if chunk.Done {
					break
				}
			}
		} else if llmResp, ok := resp.(*worker.LLMResponse); ok {
			// Handle non-streaming response
			w.Header().Set("Content-Type", "application/json")
			json.NewEncoder(w).Encode(llmResp)
		} else {
			http.Error(w, "Unexpected response type", http.StatusInternalServerError)
		}
	})
}

func (ls *LivepeerServer) SegmentAnything2() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		remoteAddr := getRemoteAddr(r)
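For reference, an end-to-end client sketch against the gateway's new /llm route. The form field names ("prompt", "model_id", "stream", "max_tokens"), the example model ID, and the gateway address/port are assumptions based on the worker's OpenAPI schema and a typical local setup, not confirmed by this diff:

package main

import (
	"bytes"
	"fmt"
	"mime/multipart"
	"net/http"
)

func main() {
	var buf bytes.Buffer
	mw := multipart.NewWriter(&buf)
	_ = mw.WriteField("prompt", "Say hello in one sentence.")
	_ = mw.WriteField("model_id", "meta-llama/Meta-Llama-3.1-8B-Instruct") // hypothetical model ID
	_ = mw.WriteField("stream", "true")    // SSE response when true
	_ = mw.WriteField("max_tokens", "256") // the orchestrator's handleAIRequest rejects requests without MaxTokens
	mw.Close()

	// Assumed local gateway address; adjust to your deployment.
	req, err := http.NewRequest("POST", "http://localhost:8935/llm", &buf)
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", mw.FormDataContentType())

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// A streaming response arrives as text/event-stream; see the SSE reader sketch above.
	fmt.Println("status:", resp.Status, "content-type:", resp.Header.Get("Content-Type"))
}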