Commit 2095440
update for new codegen and update to segfault fix
ad-astra-video committed Sep 21, 2024
1 parent 9edafa2 commit 2095440
Showing 6 changed files with 15 additions and 14 deletions.
core/ai.go: 2 changes (1 addition, 1 deletion)
@@ -23,7 +23,7 @@ type AI interface {
 	Upscale(context.Context, worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
 	AudioToText(context.Context, worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
 	SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
-	LlmGenerate(context.Context, worker.LlmGenerateLlmGeneratePostFormdataRequestBody) (interface{}, error)
+	LlmGenerate(context.Context, worker.GenLlmFormdataRequestBody) (interface{}, error)
 	Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
 	Stop(context.Context) error
 	HasCapacity(pipeline, modelID string) bool
core/orchestrator.go: 4 changes (2 additions, 2 deletions)
@@ -135,7 +135,7 @@ func (orch *orchestrator) SegmentAnything2(ctx context.Context, req worker.GenSe
 }
 
 // Return type is LlmResponse, but a stream is available as well as chan(string)
-func (orch *orchestrator) LlmGenerate(ctx context.Context, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody) (interface{}, error) {
+func (orch *orchestrator) LlmGenerate(ctx context.Context, req worker.GenLlmFormdataRequestBody) (interface{}, error) {
 	return orch.node.llmGenerate(ctx, req)
 }
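
Note: the comment in this hunk says the `interface{}` result is either a full `LlmResponse` or a stream. A minimal sketch of how a caller might branch on the two shapes this diff uses elsewhere (`*worker.LlmResponse` and `chan worker.LlmStreamChunk`); the function, field names, and import path are illustrative assumptions, not code from this commit:

```go
import (
	"fmt"
	"strings"

	"github.com/livepeer/ai-worker/worker" // import path assumed
)

// consumeLlmResult is a sketch only. It branches on the two result shapes
// described above: a complete response, or a channel of stream chunks that
// the producer closes when generation finishes.
func consumeLlmResult(res interface{}) (string, error) {
	switch v := res.(type) {
	case *worker.LlmResponse:
		// Non-streaming path: the full response is already assembled.
		return v.Response, nil // field name assumed
	case chan worker.LlmStreamChunk:
		// Streaming path: drain the channel until it is closed.
		var sb strings.Builder
		for chunk := range v {
			sb.WriteString(chunk.Chunk) // field name assumed
		}
		return sb.String(), nil
	default:
		return "", fmt.Errorf("unexpected LlmGenerate result type %T", res)
	}
}
```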

@@ -1056,7 +1056,7 @@ func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.GenImageToVi
 	return &worker.ImageResponse{Images: videos}, nil
 }
 
-func (n *LivepeerNode) llmGenerate(ctx context.Context, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody) (interface{}, error) {
+func (n *LivepeerNode) llmGenerate(ctx context.Context, req worker.GenLlmFormdataRequestBody) (interface{}, error) {
 	return n.AIWorker.LlmGenerate(ctx, req)
 }
server/ai_http.go: 4 changes (2 additions, 2 deletions)
@@ -195,7 +195,7 @@ func (h *lphttp) LlmGenerate() http.Handler {
 			return
 		}
 
-		var req worker.LlmGenerateLlmGeneratePostFormdataRequestBody
+		var req worker.GenLlmFormdataRequestBody
 		if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
 			respondWithError(w, err.Error(), http.StatusInternalServerError)
 			return
@@ -348,7 +348,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
 			return
 		}
 		outPixels = int64(config.Height) * int64(config.Width)
-	case worker.LlmGenerateLlmGeneratePostFormdataRequestBody:
+	case worker.GenLlmFormdataRequestBody:
 		pipeline = "llm-generate"
 		cap = core.Capability_LlmGenerate
 		modelID = *v.ModelId
server/ai_mediaserver.go: 5 changes (3 additions, 2 deletions)
@@ -436,7 +436,7 @@ func (ls *LivepeerServer) LlmGenerate() http.Handler {
 		requestID := string(core.RandomManifestID())
 		ctx = clog.AddVal(ctx, "request_id", requestID)
 
-		var req worker.LlmGenerateLlmGeneratePostFormdataRequestBody
+		var req worker.GenLlmFormdataRequestBody
 
 		multiRdr, err := r.MultipartReader()
 		if err != nil {
@@ -450,9 +450,10 @@ }
 		}
 
 		streamResponse := false
-		if *req.Stream {
+		if req.Stream != nil {
 			streamResponse = *req.Stream
 		}
+
 		clog.V(common.VERBOSE).Infof(ctx, "Received LlmGenerate request prompt=%v model_id=%v stream=%v", req.Prompt, *req.ModelId, streamResponse)
 
 		params := aiRequestParams{
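
Note: this hunk is the segfault fix named in the commit title. `Stream` is generated as a `*bool` because the multipart form field is optional, so `*req.Stream` panics with a nil-pointer dereference whenever a client omits it. A standalone sketch of the failure mode and the guard; the struct here is a stand-in for the generated request type, not code from this commit:

```go
package main

import "fmt"

// Stand-in for the generated request body: optional form fields come
// back as pointers, and nil means "not sent by the client".
type llmRequest struct {
	Stream *bool
}

func main() {
	req := llmRequest{} // client omitted the "stream" field

	// Before the fix, the handler dereferenced unconditionally:
	//   streamResponse := *req.Stream // panics: nil pointer dereference

	// After the fix, it defaults to false and dereferences only when set:
	streamResponse := false
	if req.Stream != nil {
		streamResponse = *req.Stream
	}
	fmt.Println(streamResponse) // prints: false
}
```
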
server/ai_process.go: 12 changes (6 additions, 6 deletions)
@@ -801,7 +801,7 @@ func CalculateLlmGenerateLatencyScore(took time.Duration, tokensUsed int) float6
 	return took.Seconds() / float64(tokensUsed)
 }
 
-func processLlmGenerate(ctx context.Context, params aiRequestParams, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody) (interface{}, error) {
+func processLlmGenerate(ctx context.Context, params aiRequestParams, req worker.GenLlmFormdataRequestBody) (interface{}, error) {
 	resp, err := processAIRequest(ctx, params, req)
 	if err != nil {
 		return nil, err
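
Note: for context, the latency score computed by `CalculateLlmGenerateLatencyScore` in this hunk is seconds per generated token. A worked example with made-up numbers, not measurements from this commit:

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	// Illustrative values: 2 seconds to generate 100 tokens.
	took := 2 * time.Second
	tokensUsed := 100
	score := took.Seconds() / float64(tokensUsed)
	fmt.Println(score) // prints: 0.02 (seconds per token)
}
```
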
@@ -823,7 +823,7 @@ func processLlmGenerate(ctx context.Context, params aiRequestParams, req worker.
 	return llmResp, nil
 }
 
-func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISession, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody) (interface{}, error) {
+func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenLlmFormdataRequestBody) (interface{}, error) {
 	var buf bytes.Buffer
 	mw, err := worker.NewLlmGenerateMultipartWriter(&buf, req)
 	if err != nil {
@@ -856,7 +856,7 @@ func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISess
 	defer completeBalanceUpdate(sess.BroadcastSession, balUpdate)
 
 	start := time.Now()
-	resp, err := client.LlmGenerateLlmGeneratePostWithBody(ctx, mw.FormDataContentType(), &buf, setHeaders)
+	resp, err := client.GenLlmWithBody(ctx, mw.FormDataContentType(), &buf, setHeaders)
 	if err != nil {
 		if monitor.Enabled {
 			monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
@@ -876,7 +876,7 @@ func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISess
 	return handleNonStreamingResponse(ctx, resp.Body, sess, req, start)
 }
 
-func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody, start time.Time) (chan worker.LlmStreamChunk, error) {
+func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLlmFormdataRequestBody, start time.Time) (chan worker.LlmStreamChunk, error) {
 	streamChan := make(chan worker.LlmStreamChunk, 100)
 	go func() {
 		defer close(streamChan)
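
Note: `handleSSEStream` feeds a buffered channel (capacity 100) from the orchestrator's SSE response body. A generic sketch of that reader shape, assuming `data: <json>` event framing; this is not the commit's exact implementation, and the import path is assumed:

```go
import (
	"bufio"
	"encoding/json"
	"io"
	"strings"

	"github.com/livepeer/ai-worker/worker" // import path assumed
)

// readSSE drains an SSE body into out, closing out when the stream ends.
// Sketch only: event framing and error handling are simplified.
func readSSE(body io.ReadCloser, out chan<- worker.LlmStreamChunk) {
	defer close(out)
	defer body.Close()

	scanner := bufio.NewScanner(body)
	for scanner.Scan() {
		data, ok := strings.CutPrefix(scanner.Text(), "data: ")
		if !ok {
			continue // skip blank separators and non-data fields
		}
		var chunk worker.LlmStreamChunk
		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
			continue // tolerate malformed events in this sketch
		}
		out <- chunk
	}
}
```
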
@@ -919,7 +919,7 @@ func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, r
 	return streamChan, nil
 }
 
-func handleNonStreamingResponse(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody, start time.Time) (*worker.LlmResponse, error) {
+func handleNonStreamingResponse(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLlmFormdataRequestBody, start time.Time) (*worker.LlmResponse, error) {
 	data, err := io.ReadAll(body)
 	defer body.Close()
 	if err != nil {
@@ -1011,7 +1011,7 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
 		submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
 			return submitSegmentAnything2(ctx, params, sess, v)
 		}
-	case worker.LlmGenerateLlmGeneratePostFormdataRequestBody:
+	case worker.GenLlmFormdataRequestBody:
 		cap = core.Capability_LlmGenerate
 		modelID = defaultLlmGenerateModelID
 		if v.ModelId != nil {
server/rpc.go: 2 changes (1 addition, 1 deletion)
@@ -69,7 +69,7 @@ type Orchestrator interface {
 	Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
 	AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
 	SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
-	LlmGenerate(ctx context.Context, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody) (interface{}, error)
+	LlmGenerate(ctx context.Context, req worker.GenLlmFormdataRequestBody) (interface{}, error)
 }
 
 // Balance describes methods for a session's balance maintenance
