Skip to content

Commit

Permalink
Refactor to remove the repetition (#3203)
Browse files Browse the repository at this point in the history
* Refactor to remove the repetition

* Fix debug logging
  • Loading branch information
mjh1 authored Oct 16, 2024
1 parent f2e1832 commit 4390579
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 227 deletions.
250 changes: 23 additions & 227 deletions server/ai_mediaserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,94 +64,52 @@ func startAIMediaServer(ls *LivepeerServer) error {

openapi3filter.RegisterBodyDecoder("image/png", openapi3filter.FileBodyDecoder)

ls.HTTPMux.Handle("/text-to-image", oapiReqValidator(ls.TextToImage()))
ls.HTTPMux.Handle("/image-to-image", oapiReqValidator(ls.ImageToImage()))
ls.HTTPMux.Handle("/upscale", oapiReqValidator(ls.Upscale()))
ls.HTTPMux.Handle("/text-to-image", oapiReqValidator(handle(ls, jsonDecoder[worker.GenTextToImageJSONRequestBody], processTextToImage)))
ls.HTTPMux.Handle("/image-to-image", oapiReqValidator(handle(ls, multipartDecoder[worker.GenImageToImageMultipartRequestBody], processImageToImage)))
ls.HTTPMux.Handle("/upscale", oapiReqValidator(handle(ls, multipartDecoder[worker.GenUpscaleMultipartRequestBody], processUpscale)))
ls.HTTPMux.Handle("/image-to-video", oapiReqValidator(ls.ImageToVideo()))
ls.HTTPMux.Handle("/image-to-video/result", ls.ImageToVideoResult())
ls.HTTPMux.Handle("/audio-to-text", oapiReqValidator(ls.AudioToText()))
ls.HTTPMux.Handle("/audio-to-text", oapiReqValidator(handle(ls, multipartDecoder[worker.GenAudioToTextMultipartRequestBody], processAudioToText)))
ls.HTTPMux.Handle("/llm", oapiReqValidator(ls.LLM()))
ls.HTTPMux.Handle("/segment-anything-2", oapiReqValidator(ls.SegmentAnything2()))
ls.HTTPMux.Handle("/segment-anything-2", oapiReqValidator(handle(ls, multipartDecoder[worker.GenSegmentAnything2MultipartRequestBody], processSegmentAnything2)))

return nil
}

func (ls *LivepeerServer) TextToImage() http.Handler {
// Decoder for JSON requests
func jsonDecoder[T any](req *T, r *http.Request) error {
return json.NewDecoder(r.Body).Decode(req)
}

// Decoder for Multipart requests
func multipartDecoder[T any](req *T, r *http.Request) error {
multiRdr, err := r.MultipartReader()
if err != nil {
return err
}
return runtime.BindMultipart(req, *multiRdr)
}

func handle[I, O any](ls *LivepeerServer, decoderFunc func(*I, *http.Request) error, processorFunc func(context.Context, aiRequestParams, I) (O, error)) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
remoteAddr := getRemoteAddr(r)
ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)
requestID := string(core.RandomManifestID())
ctx = clog.AddVal(ctx, "request_id", requestID)

var req worker.GenTextToImageJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
}

clog.V(common.VERBOSE).Infof(ctx, "Received TextToImage request prompt=%v model_id=%v", req.Prompt, *req.ModelId)

params := aiRequestParams{
node: ls.LivepeerNode,
os: drivers.NodeStorage.NewSession(requestID),
sessManager: ls.AISessionManager,
}

start := time.Now()
resp, err := processTextToImage(ctx, params, req)
if err != nil {
var serviceUnavailableErr *ServiceUnavailableError
var badRequestErr *BadRequestError
if errors.As(err, &serviceUnavailableErr) {
respondJsonError(ctx, w, err, http.StatusServiceUnavailable)
return
}
if errors.As(err, &badRequestErr) {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
}
respondJsonError(ctx, w, err, http.StatusInternalServerError)
return
}

took := time.Since(start)
clog.Infof(ctx, "Processed TextToImage request prompt=%v model_id=%v took=%v", req.Prompt, *req.ModelId, took)

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
})
}

func (ls *LivepeerServer) ImageToImage() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
remoteAddr := getRemoteAddr(r)
ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)
requestID := string(core.RandomManifestID())
ctx = clog.AddVal(ctx, "request_id", requestID)

multiRdr, err := r.MultipartReader()
if err != nil {
var req I
if err := decoderFunc(&req, r); err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
}

var req worker.GenImageToImageMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
}

clog.V(common.VERBOSE).Infof(ctx, "Received ImageToImage request imageSize=%v prompt=%v model_id=%v", req.Image.FileSize(), req.Prompt, *req.ModelId)

params := aiRequestParams{
node: ls.LivepeerNode,
os: drivers.NodeStorage.NewSession(requestID),
sessManager: ls.AISessionManager,
}

start := time.Now()
resp, err := processImageToImage(ctx, params, req)
resp, err := processorFunc(ctx, params, req)
if err != nil {
var serviceUnavailableErr *ServiceUnavailableError
var badRequestErr *BadRequestError
Expand All @@ -167,9 +125,6 @@ func (ls *LivepeerServer) ImageToImage() http.Handler {
return
}

took := time.Since(start)
clog.V(common.VERBOSE).Infof(ctx, "Processed ImageToImage request imageSize=%v prompt=%v model_id=%v took=%v", req.Image.FileSize(), req.Prompt, *req.ModelId, took)

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
Expand Down Expand Up @@ -290,112 +245,6 @@ func (ls *LivepeerServer) ImageToVideo() http.Handler {
})
}

func (ls *LivepeerServer) Upscale() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
remoteAddr := getRemoteAddr(r)
ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)
requestID := string(core.RandomManifestID())
ctx = clog.AddVal(ctx, "request_id", requestID)

multiRdr, err := r.MultipartReader()
if err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
}

var req worker.GenUpscaleMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
}

clog.V(common.VERBOSE).Infof(ctx, "Received Upscale request imageSize=%v prompt=%v model_id=%v", req.Image.FileSize(), req.Prompt, *req.ModelId)

params := aiRequestParams{
node: ls.LivepeerNode,
os: drivers.NodeStorage.NewSession(requestID),
sessManager: ls.AISessionManager,
}

start := time.Now()
resp, err := processUpscale(ctx, params, req)
if err != nil {
var serviceUnavailableErr *ServiceUnavailableError
var badRequestErr *BadRequestError
if errors.As(err, &serviceUnavailableErr) {
respondJsonError(ctx, w, err, http.StatusServiceUnavailable)
return
}
if errors.As(err, &badRequestErr) {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
}
respondJsonError(ctx, w, err, http.StatusInternalServerError)
return
}

took := time.Since(start)
clog.V(common.VERBOSE).Infof(ctx, "Processed Upscale request imageSize=%v prompt=%v model_id=%v took=%v", req.Image.FileSize(), req.Prompt, *req.ModelId, took)

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
})
}

func (ls *LivepeerServer) AudioToText() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
remoteAddr := getRemoteAddr(r)
ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)
requestID := string(core.RandomManifestID())
ctx = clog.AddVal(ctx, "request_id", requestID)

multiRdr, err := r.MultipartReader()
if err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
}

var req worker.GenAudioToTextMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
}

clog.V(common.VERBOSE).Infof(ctx, "Received AudioToText request audioSize=%v model_id=%v", req.Audio.FileSize(), *req.ModelId)

params := aiRequestParams{
node: ls.LivepeerNode,
os: drivers.NodeStorage.NewSession(requestID),
sessManager: ls.AISessionManager,
}

start := time.Now()
resp, err := processAudioToText(ctx, params, req)
if err != nil {
var serviceUnavailableErr *ServiceUnavailableError
var badRequestErr *BadRequestError
if errors.As(err, &serviceUnavailableErr) {
respondJsonError(ctx, w, err, http.StatusServiceUnavailable)
return
}
if errors.As(err, &badRequestErr) {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
}
respondJsonError(ctx, w, err, http.StatusInternalServerError)
return
}

took := time.Since(start)
clog.V(common.VERBOSE).Infof(ctx, "Processed AudioToText request model_id=%v took=%v", *req.ModelId, took)

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
})
}

func (ls *LivepeerServer) LLM() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
remoteAddr := getRemoteAddr(r)
Expand Down Expand Up @@ -463,59 +312,6 @@ func (ls *LivepeerServer) LLM() http.Handler {
})
}

func (ls *LivepeerServer) SegmentAnything2() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
remoteAddr := getRemoteAddr(r)
ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)
requestID := string(core.RandomManifestID())
ctx = clog.AddVal(ctx, "request_id", requestID)

multiRdr, err := r.MultipartReader()
if err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
}

var req worker.GenSegmentAnything2MultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
}

clog.V(common.VERBOSE).Infof(ctx, "Received SegmentAnything2 request; image_size=%v model_id=%v", req.Image.FileSize(), *req.ModelId)

params := aiRequestParams{
node: ls.LivepeerNode,
os: drivers.NodeStorage.NewSession(requestID),
sessManager: ls.AISessionManager,
}

start := time.Now()
resp, err := processSegmentAnything2(ctx, params, req)
if err != nil {
var serviceUnavailableErr *ServiceUnavailableError
var badRequestErr *BadRequestError
if errors.As(err, &serviceUnavailableErr) {
respondJsonError(ctx, w, err, http.StatusServiceUnavailable)
return
}
if errors.As(err, &badRequestErr) {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
}
respondJsonError(ctx, w, err, http.StatusInternalServerError)
return
}

took := time.Since(start)
clog.V(common.VERBOSE).Infof(ctx, "Processed SegmentAnything2 request model_id=%v took=%v", *req.ModelId, took)

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
})
}

func (ls *LivepeerServer) ImageToVideoResult() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
remoteAddr := getRemoteAddr(r)
Expand Down
8 changes: 8 additions & 0 deletions server/ai_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,7 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
return submitTextToImage(ctx, params, sess, v)
}
ctx = clog.AddVal(ctx, "prompt", v.Prompt)
case worker.GenImageToImageMultipartRequestBody:
cap = core.Capability_ImageToImage
modelID = defaultImageToImageModelID
Expand All @@ -995,6 +996,7 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
return submitImageToImage(ctx, params, sess, v)
}
ctx = clog.AddVal(ctx, "prompt", v.Prompt)
case worker.GenImageToVideoMultipartRequestBody:
cap = core.Capability_ImageToVideo
modelID = defaultImageToVideoModelID
Expand All @@ -1013,6 +1015,7 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
return submitUpscale(ctx, params, sess, v)
}
ctx = clog.AddVal(ctx, "prompt", v.Prompt)
case worker.GenAudioToTextMultipartRequestBody:
cap = core.Capability_AudioToText
modelID = defaultAudioToTextModelID
Expand All @@ -1031,6 +1034,7 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
return submitLLM(ctx, params, sess, v)
}
ctx = clog.AddVal(ctx, "prompt", v.Prompt)
case worker.GenSegmentAnything2MultipartRequestBody:
cap = core.Capability_SegmentAnything2
modelID = defaultSegmentAnything2ModelID
Expand All @@ -1046,6 +1050,10 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
capName := cap.String()
ctx = clog.AddVal(ctx, "capability", capName)

clog.V(common.VERBOSE).Infof(ctx, "Received AI request model_id=%s", modelID)
start := time.Now()
defer clog.Infof(ctx, "Processed AI request model_id=%v took=%v", modelID, time.Since(start))

var resp interface{}

cctx, cancel := context.WithTimeout(ctx, processingRetryTimeout)
Expand Down

0 comments on commit 4390579

Please sign in to comment.