Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(ai): remove code redundancy in AI orchestrator server #3226

Open
wants to merge 1 commit into
base: ai-video
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 14 additions & 174 deletions server/ai_http.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package server

// ai_http.go implements the HTTP server for AI-related requests at the Orchestrator.

import (
"bufio"
"context"
Expand All @@ -23,7 +25,6 @@
"github.com/livepeer/go-livepeer/core"
"github.com/livepeer/go-livepeer/monitor"
middleware "github.com/oapi-codegen/nethttp-middleware"
"github.com/oapi-codegen/runtime"
)

var MaxAIRequestSize = 3000000000 // 3GB
Expand All @@ -48,193 +49,32 @@

openapi3filter.RegisterBodyDecoder("image/png", openapi3filter.FileBodyDecoder)

lp.transRPC.Handle("/text-to-image", oapiReqValidator(lp.TextToImage()))
lp.transRPC.Handle("/image-to-image", oapiReqValidator(lp.ImageToImage()))
lp.transRPC.Handle("/image-to-video", oapiReqValidator(lp.ImageToVideo()))
lp.transRPC.Handle("/upscale", oapiReqValidator(lp.Upscale()))
lp.transRPC.Handle("/audio-to-text", oapiReqValidator(lp.AudioToText()))
lp.transRPC.Handle("/llm", oapiReqValidator(lp.LLM()))
lp.transRPC.Handle("/segment-anything-2", oapiReqValidator(lp.SegmentAnything2()))
lp.transRPC.Handle("/image-to-text", oapiReqValidator(lp.ImageToText()))
lp.transRPC.Handle("/text-to-image", oapiReqValidator(aiHttpHandle(&lp, jsonDecoder[worker.GenTextToImageJSONRequestBody])))
lp.transRPC.Handle("/image-to-image", oapiReqValidator(aiHttpHandle(&lp, multipartDecoder[worker.GenImageToImageMultipartRequestBody])))
lp.transRPC.Handle("/image-to-video", oapiReqValidator(aiHttpHandle(&lp, multipartDecoder[worker.GenImageToVideoMultipartRequestBody])))
lp.transRPC.Handle("/upscale", oapiReqValidator(aiHttpHandle(&lp, multipartDecoder[worker.GenUpscaleMultipartRequestBody])))
lp.transRPC.Handle("/audio-to-text", oapiReqValidator(aiHttpHandle(&lp, multipartDecoder[worker.GenAudioToTextMultipartRequestBody])))
lp.transRPC.Handle("/llm", oapiReqValidator(aiHttpHandle(&lp, jsonDecoder[worker.GenLLMFormdataRequestBody])))
lp.transRPC.Handle("/segment-anything-2", oapiReqValidator(aiHttpHandle(&lp, multipartDecoder[worker.GenSegmentAnything2MultipartRequestBody])))
lp.transRPC.Handle("/image-to-text", oapiReqValidator(aiHttpHandle(&lp, multipartDecoder[worker.GenImageToTextMultipartRequestBody])))

Check warning on line 59 in server/ai_http.go

View check run for this annotation

Codecov / codecov/patch

server/ai_http.go#L52-L59

Added lines #L52 - L59 were not covered by tests
// Additionally, there is the '/aiResults' endpoint registered in server/rpc.go

return nil
}

func (h *lphttp) TextToImage() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
orch := h.orchestrator

remoteAddr := getRemoteAddr(r)
ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)

var req worker.GenTextToImageJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
respondWithError(w, err.Error(), http.StatusBadRequest)
return
}

handleAIRequest(ctx, w, r, orch, req)
})
}

func (h *lphttp) ImageToImage() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
orch := h.orchestrator

remoteAddr := getRemoteAddr(r)
ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)

multiRdr, err := r.MultipartReader()
if err != nil {
respondWithError(w, err.Error(), http.StatusBadRequest)
return
}

var req worker.GenImageToImageMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondWithError(w, err.Error(), http.StatusInternalServerError)
return
}

handleAIRequest(ctx, w, r, orch, req)
})
}

func (h *lphttp) ImageToVideo() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
orch := h.orchestrator

remoteAddr := getRemoteAddr(r)
ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)

multiRdr, err := r.MultipartReader()
if err != nil {
respondWithError(w, err.Error(), http.StatusBadRequest)
return
}

var req worker.GenImageToVideoMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondWithError(w, err.Error(), http.StatusInternalServerError)
return
}

handleAIRequest(ctx, w, r, orch, req)
})
}

func (h *lphttp) Upscale() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
orch := h.orchestrator

remoteAddr := getRemoteAddr(r)
ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)

multiRdr, err := r.MultipartReader()
if err != nil {
respondWithError(w, err.Error(), http.StatusBadRequest)
return
}

var req worker.GenUpscaleMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondWithError(w, err.Error(), http.StatusInternalServerError)
return
}

handleAIRequest(ctx, w, r, orch, req)
})
}

func (h *lphttp) AudioToText() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
orch := h.orchestrator

remoteAddr := getRemoteAddr(r)
ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)

multiRdr, err := r.MultipartReader()
if err != nil {
respondWithError(w, err.Error(), http.StatusBadRequest)
return
}

var req worker.GenAudioToTextMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondWithError(w, err.Error(), http.StatusInternalServerError)
return
}

handleAIRequest(ctx, w, r, orch, req)
})
}

func (h *lphttp) SegmentAnything2() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
orch := h.orchestrator

remoteAddr := getRemoteAddr(r)
ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)

multiRdr, err := r.MultipartReader()
if err != nil {
respondWithError(w, err.Error(), http.StatusBadRequest)
return
}

var req worker.GenSegmentAnything2MultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondWithError(w, err.Error(), http.StatusInternalServerError)
return
}

handleAIRequest(ctx, w, r, orch, req)
})
}

func (h *lphttp) LLM() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
orch := h.orchestrator

remoteAddr := getRemoteAddr(r)
ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)

multiRdr, err := r.MultipartReader()
if err != nil {
respondWithError(w, err.Error(), http.StatusBadRequest)
return
}

var req worker.GenLLMFormdataRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondWithError(w, err.Error(), http.StatusInternalServerError)
return
}

handleAIRequest(ctx, w, r, orch, req)
})
}

func (h *lphttp) ImageToText() http.Handler {
// aiHttpHandle handles AI requests by decoding the request body and processing it.
func aiHttpHandle[I any](h *lphttp, decoderFunc func(*I, *http.Request) error) http.Handler {

Check warning on line 66 in server/ai_http.go

View check run for this annotation

Codecov / codecov/patch

server/ai_http.go#L66

Added line #L66 was not covered by tests
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
orch := h.orchestrator

remoteAddr := getRemoteAddr(r)
ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)

multiRdr, err := r.MultipartReader()
if err != nil {
var req I
if err := decoderFunc(&req, r); err != nil {

Check warning on line 73 in server/ai_http.go

View check run for this annotation

Codecov / codecov/patch

server/ai_http.go#L72-L73

Added lines #L72 - L73 were not covered by tests
respondWithError(w, err.Error(), http.StatusBadRequest)
return
}

var req worker.GenImageToTextMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondWithError(w, err.Error(), http.StatusInternalServerError)
return
}

handleAIRequest(ctx, w, r, orch, req)
})
}
Expand Down
28 changes: 7 additions & 21 deletions server/ai_mediaserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,34 +64,20 @@ func startAIMediaServer(ls *LivepeerServer) error {

openapi3filter.RegisterBodyDecoder("image/png", openapi3filter.FileBodyDecoder)

ls.HTTPMux.Handle("/text-to-image", oapiReqValidator(handle(ls, jsonDecoder[worker.GenTextToImageJSONRequestBody], processTextToImage)))
ls.HTTPMux.Handle("/image-to-image", oapiReqValidator(handle(ls, multipartDecoder[worker.GenImageToImageMultipartRequestBody], processImageToImage)))
ls.HTTPMux.Handle("/upscale", oapiReqValidator(handle(ls, multipartDecoder[worker.GenUpscaleMultipartRequestBody], processUpscale)))
ls.HTTPMux.Handle("/text-to-image", oapiReqValidator(aiMediaServerHandle(ls, jsonDecoder[worker.GenTextToImageJSONRequestBody], processTextToImage)))
ls.HTTPMux.Handle("/image-to-image", oapiReqValidator(aiMediaServerHandle(ls, multipartDecoder[worker.GenImageToImageMultipartRequestBody], processImageToImage)))
ls.HTTPMux.Handle("/upscale", oapiReqValidator(aiMediaServerHandle(ls, multipartDecoder[worker.GenUpscaleMultipartRequestBody], processUpscale)))
ls.HTTPMux.Handle("/image-to-video", oapiReqValidator(ls.ImageToVideo()))
ls.HTTPMux.Handle("/image-to-video/result", ls.ImageToVideoResult())
ls.HTTPMux.Handle("/audio-to-text", oapiReqValidator(handle(ls, multipartDecoder[worker.GenAudioToTextMultipartRequestBody], processAudioToText)))
ls.HTTPMux.Handle("/audio-to-text", oapiReqValidator(aiMediaServerHandle(ls, multipartDecoder[worker.GenAudioToTextMultipartRequestBody], processAudioToText)))
ls.HTTPMux.Handle("/llm", oapiReqValidator(ls.LLM()))
ls.HTTPMux.Handle("/segment-anything-2", oapiReqValidator(handle(ls, multipartDecoder[worker.GenSegmentAnything2MultipartRequestBody], processSegmentAnything2)))
ls.HTTPMux.Handle("/image-to-text", oapiReqValidator(handle(ls, multipartDecoder[worker.GenImageToTextMultipartRequestBody], processImageToText)))
ls.HTTPMux.Handle("/segment-anything-2", oapiReqValidator(aiMediaServerHandle(ls, multipartDecoder[worker.GenSegmentAnything2MultipartRequestBody], processSegmentAnything2)))
ls.HTTPMux.Handle("/image-to-text", oapiReqValidator(aiMediaServerHandle(ls, multipartDecoder[worker.GenImageToTextMultipartRequestBody], processImageToText)))

return nil
}

// Decoder for JSON requests
func jsonDecoder[T any](req *T, r *http.Request) error {
return json.NewDecoder(r.Body).Decode(req)
}

// Decoder for Multipart requests
func multipartDecoder[T any](req *T, r *http.Request) error {
multiRdr, err := r.MultipartReader()
if err != nil {
return err
}
return runtime.BindMultipart(req, *multiRdr)
}

func handle[I, O any](ls *LivepeerServer, decoderFunc func(*I, *http.Request) error, processorFunc func(context.Context, aiRequestParams, I) (O, error)) http.Handler {
func aiMediaServerHandle[I, O any](ls *LivepeerServer, decoderFunc func(*I, *http.Request) error, processorFunc func(context.Context, aiRequestParams, I) (O, error)) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
remoteAddr := getRemoteAddr(r)
ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)
Expand Down
24 changes: 24 additions & 0 deletions server/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package server

// utils.go contains server utility functions.

import (
"encoding/json"
"net/http"

"github.com/oapi-codegen/runtime"
)

// Decoder for JSON requests.
func jsonDecoder[T any](req *T, r *http.Request) error {
return json.NewDecoder(r.Body).Decode(req)

Check warning on line 14 in server/utils.go

View check run for this annotation

Codecov / codecov/patch

server/utils.go#L13-L14

Added lines #L13 - L14 were not covered by tests
}

// Decoder for Multipart requests.
func multipartDecoder[T any](req *T, r *http.Request) error {
multiRdr, err := r.MultipartReader()
if err != nil {
return err
}
return runtime.BindMultipart(req, *multiRdr)

Check warning on line 23 in server/utils.go

View check run for this annotation

Codecov / codecov/patch

server/utils.go#L18-L23

Added lines #L18 - L23 were not covered by tests
}
Loading