diff --git a/cmd/livepeer/starter/starter.go b/cmd/livepeer/starter/starter.go
index e8c8e6470..72fd638de 100755
--- a/cmd/livepeer/starter/starter.go
+++ b/cmd/livepeer/starter/starter.go
@@ -1239,6 +1239,7 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
 					Models: make(map[string]*core.ModelConstraint),
 				}
 			}
+
 			model, exists := capabilityConstraints[pipelineCap].Models[config.ModelID]
 			if !exists {
 				capabilityConstraints[pipelineCap].Models[config.ModelID] = modelConstraint
diff --git a/common/util.go b/common/util.go
index 83ed01964..dd6e73b87 100644
--- a/common/util.go
+++ b/common/util.go
@@ -84,6 +84,7 @@ var (
 		"video/mp2t": ".ts",
 		"video/mp4":  ".mp4",
 		"image/png":  ".png",
+		"audio/wav":  ".wav",
 	}
 )
diff --git a/core/ai.go b/core/ai.go
index 0f7c0474f..871b99b44 100644
--- a/core/ai.go
+++ b/core/ai.go
@@ -26,6 +26,7 @@ type AI interface {
 	LLM(context.Context, worker.GenLLMFormdataRequestBody) (interface{}, error)
 	SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
 	ImageToText(context.Context, worker.GenImageToTextMultipartRequestBody) (*worker.ImageToTextResponse, error)
+	TextToSpeech(context.Context, worker.GenTextToSpeechJSONRequestBody) (*worker.AudioResponse, error)
 	Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
 	Stop(context.Context) error
 	HasCapacity(pipeline, modelID string) bool
diff --git a/core/ai_test.go b/core/ai_test.go
index 07b46bea6..dc924760e 100644
--- a/core/ai_test.go
+++ b/core/ai_test.go
@@ -659,6 +659,10 @@ func (a *stubAIWorker) ImageToText(ctx context.Context, req worker.GenImageToTex
 	return &worker.ImageToTextResponse{Text: "Transcribed text"}, nil
 }
 
+func (a *stubAIWorker) TextToSpeech(ctx context.Context, req worker.GenTextToSpeechJSONRequestBody) (*worker.AudioResponse, error) {
+	return &worker.AudioResponse{Audio: worker.MediaURL{Url: "http://example.com/audio.wav"}}, nil
+}
+
 func (a *stubAIWorker) Warm(ctx context.Context, arg1, arg2 string, endpoint worker.RunnerEndpoint, flags worker.OptimizationFlags) error {
 	return nil
 }
diff --git a/core/ai_worker.go b/core/ai_worker.go
index 32b1343bc..7bb00f0ec 100644
--- a/core/ai_worker.go
+++ b/core/ai_worker.go
@@ -422,68 +422,95 @@ func (n *LivepeerNode) saveLocalAIWorkerResults(ctx context.Context, results int
 	ext, _ := common.MimeTypeToExtension(contentType)
 	fileName := string(RandomManifestID()) + ext
 
-	imgRes, ok := results.(worker.ImageResponse)
-	if !ok {
-		// worker.TextResponse is JSON, no file save needed
-		return results, nil
-	}
 	storage, exists := n.StorageConfigs[requestID]
 	if !exists {
 		return nil, errors.New("no storage available for request")
 	}
+
 	var buf bytes.Buffer
-	for i, image := range imgRes.Images {
-		buf.Reset()
-		err := worker.ReadImageB64DataUrl(image.Url, &buf)
-		if err != nil {
-			// try to load local file (image to video returns local file)
-			f, err := os.ReadFile(image.Url)
+	switch resp := results.(type) {
+	case worker.ImageResponse:
+		for i, image := range resp.Images {
+			buf.Reset()
+			err := worker.ReadImageB64DataUrl(image.Url, &buf)
+			if err != nil {
+				// try to load local file (image to video returns local file)
+				f, err := os.ReadFile(image.Url)
+				if err != nil {
+					return nil, err
+				}
+				buf = *bytes.NewBuffer(f)
+			}
+
+			osUrl, err := storage.OS.SaveData(ctx, fileName, bytes.NewBuffer(buf.Bytes()), nil, 0)
 			if err != nil {
 				return nil, err
 			}
-			buf = *bytes.NewBuffer(f)
+
+			resp.Images[i].Url = osUrl
+		}
+
+		results = resp
+	case worker.AudioResponse:
+		err := worker.ReadAudioB64DataUrl(resp.Audio.Url, &buf)
+		if err != nil {
+			return nil, err
 		}
 
 		osUrl, err := storage.OS.SaveData(ctx, fileName, bytes.NewBuffer(buf.Bytes()), nil, 0)
 		if err != nil {
 			return nil, err
 		}
+		resp.Audio.Url = osUrl
 
-		imgRes.Images[i].Url = osUrl
+		results = resp
 	}
 
-	return imgRes, nil
+	// no file response to save, response is text
+	return results, nil
 }
 
 func (n *LivepeerNode) saveRemoteAIWorkerResults(ctx context.Context, results *RemoteAIWorkerResult, requestID string) (*RemoteAIWorkerResult, error) {
 	if drivers.NodeStorage == nil {
 		return nil, fmt.Errorf("Missing local storage")
 	}
-
+	// save the file data to node and provide url for download
+	storage, exists := n.StorageConfigs[requestID]
+	if !exists {
+		return nil, errors.New("no storage available for request")
+	}
 	// worker.ImageResponse used by ***-to-image and image-to-video require saving binary data for download
+	// worker.AudioResponse used by text-to-speech also requires saving binary data for download
 	// other pipelines do not require saving data since they are text responses
-	imgResp, isImg := results.Results.(worker.ImageResponse)
-	if isImg {
-		for idx := range imgResp.Images {
-			fileName := imgResp.Images[idx].Url
-			// save the file data to node and provide url for download
-			storage, exists := n.StorageConfigs[requestID]
-			if !exists {
-				return nil, errors.New("no storage available for request")
-			}
+	switch resp := results.Results.(type) {
+	case worker.ImageResponse:
+		for idx := range resp.Images {
+			fileName := resp.Images[idx].Url
 			osUrl, err := storage.OS.SaveData(ctx, fileName, bytes.NewReader(results.Files[fileName]), nil, 0)
 			if err != nil {
 				return nil, err
 			}
 
-			imgResp.Images[idx].Url = osUrl
+			resp.Images[idx].Url = osUrl
 			delete(results.Files, fileName)
 		}
 
 		// update results for url updates
-		results.Results = imgResp
+		results.Results = resp
+	case worker.AudioResponse:
+		fileName := resp.Audio.Url
+		osUrl, err := storage.OS.SaveData(ctx, fileName, bytes.NewReader(results.Files[fileName]), nil, 0)
+		if err != nil {
+			return nil, err
+		}
+
+		resp.Audio.Url = osUrl
+		delete(results.Files, fileName)
+
+		results.Results = resp
 	}
 
+	// no file response to save, response is text
 	return results, nil
 }
 
@@ -789,6 +816,39 @@ func (orch *orchestrator) ImageToText(ctx context.Context, requestID string, req
 	return res.Results, nil
 }
 
+func (orch *orchestrator) TextToSpeech(ctx context.Context, requestID string, req worker.GenTextToSpeechJSONRequestBody) (interface{}, error) {
+	// local AIWorker processes job if combined orchestrator/ai worker
+	if orch.node.AIWorker != nil {
+		workerResp, err := orch.node.TextToSpeech(ctx, req)
+		if err == nil {
+			return orch.node.saveLocalAIWorkerResults(ctx, *workerResp, requestID, "audio/wav")
+		} else {
+			clog.Errorf(ctx, "Error processing with local ai worker err=%q", err)
+			if monitor.Enabled {
+				monitor.AIResultSaveError(ctx, "text-to-speech", *req.ModelId, string(monitor.SegmentUploadErrorUnknown))
+			}
+			return nil, err
+		}
+	}
+
+	// remote ai worker processes job
+	res, err := orch.node.AIWorkerManager.Process(ctx, requestID, "text-to-speech", *req.ModelId, "", AIJobRequestData{Request: req})
+	if err != nil {
+		return nil, err
+	}
+
+	res, err = orch.node.saveRemoteAIWorkerResults(ctx, res, requestID)
+	if err != nil {
+		clog.Errorf(ctx, "Error saving remote ai result err=%q", err)
+		if monitor.Enabled {
+			monitor.AIResultSaveError(ctx, "text-to-speech", *req.ModelId, string(monitor.SegmentUploadErrorUnknown))
+		}
+		return nil, err
+	}
+
+	return res.Results, nil
+}
+
 // only used for sending work to remote AI worker
 func (orch *orchestrator) SaveAIRequestInput(ctx context.Context, requestID string, fileData []byte) (string, error) {
 	node := orch.node
@@ -959,6 +1019,10 @@ func (n *LivepeerNode) LLM(ctx context.Context, req worker.GenLLMFormdataRequest
 	return n.AIWorker.LLM(ctx, req)
 }
 
+func (n *LivepeerNode) TextToSpeech(ctx context.Context, req worker.GenTextToSpeechJSONRequestBody) (*worker.AudioResponse, error) {
+	return n.AIWorker.TextToSpeech(ctx, req)
+}
+
 // transcodeFrames converts a series of image URLs into a video segment for the image-to-video pipeline.
 func (n *LivepeerNode) transcodeFrames(ctx context.Context, sessionID string, urls []string, inProfile ffmpeg.VideoProfile, outProfile ffmpeg.VideoProfile) *TranscodeResult {
 	ctx = clog.AddOrchSessionID(ctx, sessionID)
diff --git a/core/capabilities.go b/core/capabilities.go
index 44b954ccb..c82ae843a 100644
--- a/core/capabilities.go
+++ b/core/capabilities.go
@@ -82,6 +82,7 @@ const (
 	Capability_LLM              Capability = 33
 	Capability_ImageToText      Capability = 34
 	Capability_LiveVideoToVideo Capability = 35
+	Capability_TextToSpeech     Capability = 36
 )
 
 var CapabilityNameLookup = map[Capability]string{
@@ -122,6 +123,7 @@ var CapabilityNameLookup = map[Capability]string{
 	Capability_LLM:              "Llm",
 	Capability_ImageToText:      "Image to text",
 	Capability_LiveVideoToVideo: "Live video to video",
+	Capability_TextToSpeech:     "Text to speech",
 }
 
 var CapabilityTestLookup = map[Capability]CapabilityTest{
@@ -214,6 +216,7 @@ func OptionalCapabilities() []Capability {
 		Capability_AudioToText,
 		Capability_SegmentAnything2,
 		Capability_ImageToText,
+		Capability_TextToSpeech,
 	}
 }
diff --git a/go.mod b/go.mod
index cefba48d8..271f6675d 100644
--- a/go.mod
+++ b/go.mod
@@ -13,7 +13,7 @@ require (
 	github.com/golang/protobuf v1.5.4
 	github.com/jaypipes/ghw v0.10.0
 	github.com/jaypipes/pcidb v1.0.0
-	github.com/livepeer/ai-worker v0.9.2
+	github.com/livepeer/ai-worker v0.11.0
 	github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b
 	github.com/livepeer/livepeer-data v0.7.5-0.20231004073737-06f1f383fb18
 	github.com/livepeer/lpms v0.0.0-20240909171057-fe5aff1fa6a2
diff --git a/go.sum b/go.sum
index 26de318a8..2ea3ae7ee 100644
--- a/go.sum
+++ b/go.sum
@@ -604,8 +604,8 @@ github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4n
 github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI=
 github.com/libp2p/go-openssl v0.1.0 h1:LBkKEcUv6vtZIQLVTegAil8jbNpJErQ9AnT+bWV+Ooo=
 github.com/libp2p/go-openssl v0.1.0/go.mod h1:OiOxwPpL3n4xlenjx2h7AwSGaFSC/KZvf6gNdOBQMtc=
-github.com/livepeer/ai-worker v0.9.2 h1:kgXb6sjfi93pJxxsAtWyAGo53/+gHsf7JMRVApor+zU=
-github.com/livepeer/ai-worker v0.9.2/go.mod h1:/Deme7XXRP4BiYXt/j694Ygw+dh8rWJdikJsKY64sjE=
+github.com/livepeer/ai-worker v0.11.0 h1:prbRRBgCIrECUuZFWuyN6z3QZLfygYqBKYMleT+I7o4=
+github.com/livepeer/ai-worker v0.11.0/go.mod h1:GjQuPmz69UO53WVtqzB9Ygok5MmKCGNuobbfMXH7zgw=
 github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b h1:VQcnrqtCA2UROp7q8ljkh2XA/u0KRgVv0S1xoUvOweE=
 github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b/go.mod h1:hwJ5DKhl+pTanFWl+EUpw1H7ukPO/H+MFpgA7jjshzw=
 github.com/livepeer/joy4 v0.1.2-0.20191121080656-b2fea45cbded h1:ZQlvR5RB4nfT+cOQee+WqmaDOgGtP2oDMhcVvR4L0yA=
diff --git a/server/ai_http.go b/server/ai_http.go
index ce7d8b76b..c26e42c3a 100644
--- a/server/ai_http.go
+++ b/server/ai_http.go
@@ -14,6 +14,7 @@ import (
 	"strconv"
 	"strings"
 	"time"
+	"unicode/utf8"
"github.com/getkin/kin-openapi/openapi3filter" "github.com/golang/glog" @@ -57,6 +58,7 @@ func startAIServer(lp lphttp) error { lp.transRPC.Handle("/segment-anything-2", oapiReqValidator(lp.SegmentAnything2())) lp.transRPC.Handle("/image-to-text", oapiReqValidator(lp.ImageToText())) lp.transRPC.Handle("/live-video-to-video", oapiReqValidator(lp.StartLiveVideoToVideo())) + lp.transRPC.Handle("/text-to-speech", oapiReqValidator(lp.TextToSpeech())) // Additionally, there is the '/aiResults' endpoint registered in server/rpc.go return nil @@ -269,6 +271,23 @@ func (h *lphttp) StartLiveVideoToVideo() http.Handler { }) } +func (h *lphttp) TextToSpeech() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + orch := h.orchestrator + + remoteAddr := getRemoteAddr(r) + ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr) + + var req worker.GenTextToSpeechJSONRequestBody + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + respondWithError(w, err.Error(), http.StatusBadRequest) + return + } + + handleAIRequest(ctx, w, r, orch, req) + }) +} + func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, orch Orchestrator, req interface{}) { payment, err := getPayment(r.Header.Get(paymentHeader)) if err != nil { @@ -448,6 +467,18 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request return } outPixels = int64(config.Height) * int64(config.Width) + case worker.GenTextToSpeechJSONRequestBody: + pipeline = "text-to-speech" + cap = core.Capability_TextToSpeech + modelID = *v.ModelId + + submitFn = func(ctx context.Context) (interface{}, error) { + return orch.TextToSpeech(ctx, requestID, v) + } + + // TTS pricing is typically in characters, including punctuation. 
+		numChars := utf8.RuneCountInString(*v.Text)
+		outPixels = int64(1000 * numChars)
 	default:
 		respondWithError(w, "Unknown request type", http.StatusBadRequest)
 		return
 	}
@@ -551,6 +582,8 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
 		latencyScore = CalculateSegmentAnything2LatencyScore(took, outPixels)
 	case worker.GenImageToTextMultipartRequestBody:
 		latencyScore = CalculateImageToTextLatencyScore(took, outPixels)
+	case worker.GenTextToSpeechJSONRequestBody:
+		latencyScore = CalculateTextToSpeechLatencyScore(took, outPixels)
 	}
 
 	var pricePerAIUnit float64
@@ -753,13 +786,22 @@ func parseMultiPartResult(body io.Reader, boundary string, pipeline string) core
 			break
 		}
 		results = parsedResp
-	case "audio-to-text", "segment-anything-2", "llm":
+	case "audio-to-text", "segment-anything-2", "llm", "image-to-text":
 		err := json.Unmarshal(body, &results)
 		if err != nil {
 			glog.Error("Error getting results json:", err)
 			wkrResult.Err = err
 			break
 		}
+	case "text-to-speech":
+		var parsedResp worker.AudioResponse
+		err := json.Unmarshal(body, &parsedResp)
+		if err != nil {
+			glog.Error("Error getting results json:", err)
+			wkrResult.Err = err
+			break
+		}
+		results = parsedResp
 	}
 
 	wkrResult.Results = results
diff --git a/server/ai_mediaserver.go b/server/ai_mediaserver.go
index b80f9af31..e91f8d805 100644
--- a/server/ai_mediaserver.go
+++ b/server/ai_mediaserver.go
@@ -73,6 +73,7 @@ func startAIMediaServer(ls *LivepeerServer) error {
 	ls.HTTPMux.Handle("/llm", oapiReqValidator(ls.LLM()))
 	ls.HTTPMux.Handle("/segment-anything-2", oapiReqValidator(handle(ls, multipartDecoder[worker.GenSegmentAnything2MultipartRequestBody], processSegmentAnything2)))
 	ls.HTTPMux.Handle("/image-to-text", oapiReqValidator(handle(ls, multipartDecoder[worker.GenImageToTextMultipartRequestBody], processImageToText)))
+	ls.HTTPMux.Handle("/text-to-speech", oapiReqValidator(handle(ls, jsonDecoder[worker.GenTextToSpeechJSONRequestBody], processTextToSpeech)))
 
 	// This is called by the media server when the stream is ready
 	ls.HTTPMux.Handle("/live/video-to-video/start", ls.StartLiveVideo())
diff --git a/server/ai_process.go b/server/ai_process.go
index 6fd92fdf0..0369fc2f1 100644
--- a/server/ai_process.go
+++ b/server/ai_process.go
@@ -35,6 +35,7 @@ const defaultLLMModelID = "meta-llama/llama-3.1-8B-Instruct"
 const defaultSegmentAnything2ModelID = "facebook/sam2-hiera-large"
 const defaultImageToTextModelID = "Salesforce/blip-image-captioning-large"
 const defaultLiveVideoToVideoModelID = "cumulo-autumn/stream-diffusion"
+const defaultTextToSpeechModelID = "parler-tts/parler-tts-large-v1"
 
 var errWrongFormat = fmt.Errorf("result not in correct format")
 
@@ -59,6 +60,9 @@ func parseBadRequestError(err error) *BadRequestError {
 	if err == nil {
 		return nil
 	}
+	if err, ok := err.(*BadRequestError); ok {
+		return err
+	}
 
 	const errorCode = "returned 400"
 	if !strings.Contains(err.Error(), errorCode) {
@@ -756,6 +760,122 @@ func submitSegmentAnything2(ctx context.Context, params aiRequestParams, sess *A
 	return resp.JSON200, nil
 }
 
+// CalculateTextToSpeechLatencyScore computes the time taken per character for a TextToSpeech request.
+func CalculateTextToSpeechLatencyScore(took time.Duration, inCharacters int64) float64 {
+	if inCharacters <= 0 {
+		return 0
+	}
+
+	return took.Seconds() / float64(inCharacters)
+}
+
+func processTextToSpeech(ctx context.Context, params aiRequestParams, req worker.GenTextToSpeechJSONRequestBody) (*worker.AudioResponse, error) {
+	resp, err := processAIRequest(ctx, params, req)
+	if err != nil {
+		return nil, err
+	}
+
+	audioResp, ok := resp.(*worker.AudioResponse)
+	if !ok {
+		return nil, errWrongFormat
+	}
+
+	var result []byte
+	var data bytes.Buffer
+	var name string
+	writer := bufio.NewWriter(&data)
+	err = worker.ReadAudioB64DataUrl(audioResp.Audio.Url, writer)
+	if err == nil {
+		// orchestrator sent base64 encoded result in .Url
+		name = string(core.RandomManifestID()) + ".wav"
+		writer.Flush()
+		result = data.Bytes()
+	} else {
+		// orchestrator sent download url, get the data
+		name = filepath.Base(audioResp.Audio.Url)
+		result, err = core.DownloadData(ctx, audioResp.Audio.Url)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	newUrl, err := params.os.SaveData(ctx, name, bytes.NewReader(result), nil, 0)
+	if err != nil {
+		return nil, fmt.Errorf("error saving audio to objectStore: %w", err)
+	}
+
+	audioResp.Audio.Url = newUrl
+	return audioResp, nil
+}
+
+func submitTextToSpeech(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenTextToSpeechJSONRequestBody) (*worker.AudioResponse, error) {
+	client, err := worker.NewClientWithResponses(sess.Transcoder(), worker.WithHTTPClient(httpClient))
+	if err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "text-to-speech", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+
+	if req.Text == nil {
+		return nil, &BadRequestError{errors.New("text field is required")}
+	}
+
+	textLength := len(*req.Text)
+	clog.V(common.VERBOSE).Infof(ctx, "Submitting text-to-speech request with text length: %d", textLength)
+	inCharacters := int64(textLength)
+	setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, inCharacters)
+	if err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "text-to-speech", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+	defer completeBalanceUpdate(sess.BroadcastSession, balUpdate)
+
+	start := time.Now()
+	resp, err := client.GenTextToSpeechWithResponse(ctx, req, setHeaders)
+	took := time.Since(start)
+	if err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "text-to-speech", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+
+	if resp.JSON200 == nil {
+		// TODO: Replace trim newline with better error spec from O
+		return nil, errors.New(strings.TrimSuffix(string(resp.Body), "\n"))
+	}
+
+	// We treat a response as "receiving change" where the change is the difference between the credit and debit for the update
+	if balUpdate != nil {
+		balUpdate.Status = ReceivedChange
+	}
+
+	// TODO: Refine this rough estimate in future iterations
+	sess.LatencyScore = CalculateTextToSpeechLatencyScore(took, inCharacters)
+
+	if monitor.Enabled {
+		var pricePerAIUnit float64
+		if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 {
+			pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit)
+		}
+
+		monitor.AIRequestFinished(ctx, "text-to-speech", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
+	}
+
+	var res worker.AudioResponse
+	if err := json.Unmarshal(resp.Body, &res); err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "text-to-speech", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+
+	return &res, nil
+}
+
 // CalculateAudioToTextLatencyScore computes the time taken per second of audio for an audio-to-text request.
 func CalculateAudioToTextLatencyScore(took time.Duration, durationSeconds int64) float64 {
 	if durationSeconds <= 0 {
@@ -1218,6 +1338,15 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
 		submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
 			return submitImageToText(ctx, params, sess, v)
 		}
+	case worker.GenTextToSpeechJSONRequestBody:
+		cap = core.Capability_TextToSpeech
+		modelID = defaultTextToSpeechModelID
+		if v.ModelId != nil {
+			modelID = *v.ModelId
+		}
+		submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
+			return submitTextToSpeech(ctx, params, sess, v)
+		}
 	/*
 		case worker.StartLiveVideoToVideoFormdataRequestBody:
 			cap = core.Capability_LiveVideoToVideo
diff --git a/server/ai_worker.go b/server/ai_worker.go
index 4606a5b6e..97b2cda05 100644
--- a/server/ai_worker.go
+++ b/server/ai_worker.go
@@ -285,6 +285,35 @@ func runAIJob(n *core.LivepeerNode, orchAddr string, httpc *http.Client, notify
 			return n.LLM(ctx, req)
 		}
 		reqOk = true
+	case "image-to-text":
+		var req worker.GenImageToTextMultipartRequestBody
+		err = json.Unmarshal(reqData.Request, &req)
+		if err != nil || req.ModelId == nil {
+			break
+		}
+		input, err = core.DownloadData(ctx, reqData.InputUrl)
+		if err != nil {
+			break
+		}
+		modelID = *req.ModelId
+		resultType = "application/json"
+		req.Image.InitFromBytes(input, "image")
+		processFn = func(ctx context.Context) (interface{}, error) {
+			return n.ImageToText(ctx, req)
+		}
+		reqOk = true
+	case "text-to-speech":
+		var req worker.GenTextToSpeechJSONRequestBody
+		err = json.Unmarshal(reqData.Request, &req)
+		if err != nil || req.ModelId == nil {
+			break
+		}
+		modelID = *req.ModelId
+		resultType = "audio/wav"
+		processFn = func(ctx context.Context) (interface{}, error) {
+			return n.TextToSpeech(ctx, req)
+		}
+		reqOk = true
 	default:
 		err = errors.New("AI request pipeline type not supported")
 	}
@@ -339,29 +368,29 @@ func runAIJob(n *core.LivepeerNode, orchAddr string, httpc *http.Client, notify
 	// Parse data from runner to send back to orchestrator
 	// ***-to-image gets base64 encoded string of binary image from runner
 	// image-to-video processes frames from runner and returns ImageResponse with url to local file
-	imgResp, isImg := resp.(*worker.ImageResponse)
-	if isImg {
-		var imgBuf bytes.Buffer
-		for i, image := range imgResp.Images {
+	var resBuf bytes.Buffer
+	length := 0
+	switch wkrResp := resp.(type) {
+	case *worker.ImageResponse:
+		for i, image := range wkrResp.Images {
 			// read the data to binary and replace the url
-			length := 0
 			switch resultType {
 			case "image/png":
-				err := worker.ReadImageB64DataUrl(image.Url, &imgBuf)
+				err := worker.ReadImageB64DataUrl(image.Url, &resBuf)
 				if err != nil {
 					clog.Errorf(ctx, "AI Worker failed to save image from data url err=%q", err)
 					sendAIResult(ctx, n, orchAddr, notify.AIJobData.Pipeline, modelID, httpc, contentType, &body, err)
 					return
 				}
-				length = imgBuf.Len()
-				imgResp.Images[i].Url = fmt.Sprintf("%v.png", core.RandomManifestID()) // update json response to track filename attached
+				length = resBuf.Len()
+				wkrResp.Images[i].Url = fmt.Sprintf("%v.png", core.RandomManifestID()) // update json response to track filename attached
 
 				// create the part
 				w.SetBoundary(boundary)
 				hdrs := textproto.MIMEHeader{
 					"Content-Type":        {resultType},
 					"Content-Length":      {strconv.Itoa(length)},
-					"Content-Disposition": {"attachment; filename=" + imgResp.Images[i].Url},
+					"Content-Disposition": {"attachment; filename=" + wkrResp.Images[i].Url},
 				}
 				fw, err := w.CreatePart(hdrs)
 				if err != nil {
@@ -369,8 +398,8 @@ func runAIJob(n *core.LivepeerNode, orchAddr string, httpc *http.Client, notify
 					sendAIResult(ctx, n, orchAddr, notify.AIJobData.Pipeline, modelID, httpc, contentType, nil, err)
 					return
 				}
-				io.Copy(fw, &imgBuf)
-				imgBuf.Reset()
+				io.Copy(fw, &resBuf)
+				resBuf.Reset()
 			case "video/mp4":
 				// transcoded result is saved as local file
 				// TODO: enhance this to return the []bytes from transcoding in n.ImageToVideo create the part
 				f, err := os.ReadFile(image.Url)
 				if err != nil {
@@ -381,12 +410,12 @@ func runAIJob(n *core.LivepeerNode, orchAddr string, httpc *http.Client, notify
 					sendAIResult(ctx, n, orchAddr, notify.AIJobData.Pipeline, modelID, httpc, contentType, nil, err)
 					return
 				}
 				defer os.Remove(image.Url)
-				imgResp.Images[i].Url = fmt.Sprintf("%v.mp4", core.RandomManifestID())
+				wkrResp.Images[i].Url = fmt.Sprintf("%v.mp4", core.RandomManifestID())
 				w.SetBoundary(boundary)
 				hdrs := textproto.MIMEHeader{
 					"Content-Type":        {resultType},
 					"Content-Length":      {strconv.Itoa(len(f))},
-					"Content-Disposition": {"attachment; filename=" + imgResp.Images[i].Url},
+					"Content-Disposition": {"attachment; filename=" + wkrResp.Images[i].Url},
 				}
 				fw, err := w.CreatePart(hdrs)
 				if err != nil {
@@ -398,7 +427,31 @@ func runAIJob(n *core.LivepeerNode, orchAddr string, httpc *http.Client, notify
 			}
 		}
 		// update resp for image.Url updates
-		resp = imgResp
+		resp = wkrResp
+	case *worker.AudioResponse:
+		err := worker.ReadAudioB64DataUrl(wkrResp.Audio.Url, &resBuf)
+		if err != nil {
+			clog.Errorf(ctx, "AI Worker failed to save audio from data url err=%q", err)
+			sendAIResult(ctx, n, orchAddr, notify.AIJobData.Pipeline, modelID, httpc, contentType, &body, err)
+			return
+		}
+		length = resBuf.Len()
+		wkrResp.Audio.Url = fmt.Sprintf("%v.wav", core.RandomManifestID()) // update json response to track filename attached
+
+		// create the part
+		w.SetBoundary(boundary)
+		hdrs := textproto.MIMEHeader{
+			"Content-Type":        {resultType},
+			"Content-Length":      {strconv.Itoa(length)},
+			"Content-Disposition": {"attachment; filename=" + wkrResp.Audio.Url},
+		}
+		fw, err := w.CreatePart(hdrs)
+		if err != nil {
+			clog.Errorf(ctx, "Could not create multipart part err=%q", err)
+			sendAIResult(ctx, n, orchAddr, notify.AIJobData.Pipeline, modelID, httpc, contentType, nil, err)
+			return
+		}
+		io.Copy(fw, &resBuf)
+		resBuf.Reset()
 	}
 
 	// add the json to the response
diff --git a/server/ai_worker_test.go b/server/ai_worker_test.go
index 163fc7a02..4ecbd8768 100644
--- a/server/ai_worker_test.go
+++ b/server/ai_worker_test.go
@@ -204,16 +204,30 @@ func TestRunAIJob(t *testing.T) {
 			expectedErr:     "",
 			expectedOutputs: 1,
 		},
+		{
+			name:            "ImageToText_Success",
+			notify:          createAIJob(8, "image-to-text", modelId, parsedURL.String()+"/image.png"),
+			pipeline:        "image-to-text",
+			expectedErr:     "",
+			expectedOutputs: 1,
+		},
+		{
+			name:            "TextToSpeech_Success",
+			notify:          createAIJob(9, "text-to-speech", modelId, ""),
+			pipeline:        "text-to-speech",
+			expectedErr:     "",
+			expectedOutputs: 1,
+		},
 		{
 			name:            "UnsupportedPipeline",
-			notify:          createAIJob(8, "unsupported-pipeline", modelId, ""),
+			notify:          createAIJob(10, "unsupported-pipeline", modelId, ""),
 			pipeline:        "unsupported-pipeline",
 			expectedErr:     "AI request validation failed for",
 			expectedOutputs: 0,
 		},
 		{
 			name:            "InvalidRequestData",
-			notify:          createAIJob(9, "text-to-image-invalid", modelId, ""),
+			notify:          createAIJob(11, "text-to-image-invalid", modelId, ""),
 			pipeline:        "text-to-image",
 			expectedErr:     "AI request validation failed for",
 			expectedOutputs: 0,
@@ -312,6 +326,24 @@ func TestRunAIJob(t *testing.T) {
 					assert.Equal(len(results.Files), 0)
 					expectedResp, _ := wkr.LLM(context.Background(), worker.GenLLMFormdataRequestBody{})
 					assert.Equal(expectedResp, &jsonRes)
+				case "image-to-text":
+					res, _ := json.Marshal(results.Results)
+					var jsonRes worker.ImageToTextResponse
+					json.Unmarshal(res, &jsonRes)
+
+					assert.Equal("8", headers.Get("TaskId"))
+					assert.Equal(len(results.Files), 0)
+					expectedResp, _ := wkr.ImageToText(context.Background(), worker.GenImageToTextMultipartRequestBody{})
+					assert.Equal(expectedResp, &jsonRes)
+				case "text-to-speech":
+					audResp, ok := results.Results.(worker.AudioResponse)
+					assert.True(ok)
+					assert.Equal("9", headers.Get("TaskId"))
+					assert.Equal(len(results.Files), 1)
+					expectedResp, _ := wkr.TextToSpeech(context.Background(), worker.GenTextToSpeechJSONRequestBody{})
+					var respFile bytes.Buffer
+					worker.ReadAudioB64DataUrl(expectedResp.Audio.Url, &respFile)
+					assert.Equal(len(results.Files[audResp.Audio.Url]), respFile.Len())
 				}
 			}
 		})
@@ -341,6 +373,13 @@ func createAIJob(taskId int64, pipeline, modelId, inputUrl string) *net.NotifyAI
 		req = worker.GenSegmentAnything2MultipartRequestBody{ModelId: &modelId, Image: inputFile}
 	case "llm":
 		req = worker.GenLLMFormdataRequestBody{Prompt: "tell me a story", ModelId: &modelId}
+	case "image-to-text":
+		inputFile.InitFromBytes(nil, inputUrl)
+		req = worker.GenImageToTextMultipartRequestBody{Prompt: "test prompt", ModelId: &modelId, Image: inputFile}
+	case "text-to-speech":
+		desc := "a young adult"
+		text := "let me tell you a story"
+		req = worker.GenTextToSpeechJSONRequestBody{Description: &desc, ModelId: &modelId, Text: &text}
 	case "unsupported-pipeline":
 		req = worker.GenTextToImageJSONRequestBody{Prompt: "test prompt", ModelId: &modelId}
 	case "text-to-image-invalid":
@@ -576,6 +615,17 @@ func (a *stubAIWorker) ImageToText(ctx context.Context, req worker.GenImageToTex
 	}
 }
 
+func (a *stubAIWorker) TextToSpeech(ctx context.Context, req worker.GenTextToSpeechJSONRequestBody) (*worker.AudioResponse, error) {
+	a.Called++
+	if a.Err != nil {
+		return nil, a.Err
+	} else {
+		return &worker.AudioResponse{Audio: worker.MediaURL{
+			Url: "data:audio/wav;base64,UklGRhYAAABXQVZFZm10IBAAAAABAAEAgD4AAAB9AAACABAAZGF0YQAAAAA="},
+		}, nil
+	}
+}
+
 func (a *stubAIWorker) Warm(ctx context.Context, arg1, arg2 string, endpoint worker.RunnerEndpoint, flags worker.OptimizationFlags) error {
 	a.Called++
 	return nil
diff --git a/server/rpc.go b/server/rpc.go
index 62ab1ba46..c5a11b628 100644
--- a/server/rpc.go
+++ b/server/rpc.go
@@ -76,6 +76,7 @@ type Orchestrator interface {
 	LLM(ctx context.Context, requestID string, req worker.GenLLMFormdataRequestBody) (interface{}, error)
 	SegmentAnything2(ctx context.Context, requestID string, req worker.GenSegmentAnything2MultipartRequestBody) (interface{}, error)
 	ImageToText(ctx context.Context, requestID string, req worker.GenImageToTextMultipartRequestBody) (interface{}, error)
+	TextToSpeech(ctx context.Context, requestID string, req worker.GenTextToSpeechJSONRequestBody) (interface{}, error)
 }
 
 // Balance describes methods for a session's balance maintenance
diff --git a/server/rpc_test.go b/server/rpc_test.go
index 73592a6a0..162dacc2e 100644
--- a/server/rpc_test.go
+++ b/server/rpc_test.go
@@ -219,6 +219,10 @@ func (r *stubOrchestrator) SegmentAnything2(ctx context.Context, requestID strin
 func (r *stubOrchestrator) ImageToText(ctx context.Context, requestID string, req worker.GenImageToTextMultipartRequestBody) (interface{}, error) {
 	return nil, nil
 }
+func (r *stubOrchestrator) TextToSpeech(ctx context.Context, requestID string, req worker.GenTextToSpeechJSONRequestBody) (interface{}, error) {
+	return nil, nil
+}
+
 func (r *stubOrchestrator) CheckAICapacity(pipeline, modelID string) bool {
 	return true
 }
@@ -1418,6 +1422,9 @@ func (r *mockOrchestrator) SegmentAnything2(ctx context.Context, requestID strin
 func (r *mockOrchestrator) ImageToText(ctx context.Context, requestID string, req worker.GenImageToTextMultipartRequestBody) (interface{}, error) {
 	return nil, nil
 }
+func (r *mockOrchestrator) TextToSpeech(ctx context.Context, requestID string, req worker.GenTextToSpeechJSONRequestBody) (interface{}, error) {
+	return nil, nil
+}
 func (r *mockOrchestrator) CheckAICapacity(pipeline, modelID string) bool {
 	return true
 }
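
For reference, a minimal sketch of how a gateway client could exercise the new /text-to-speech route registered above in server/ai_mediaserver.go. The gateway address/port and the snake_case JSON field names (assumed to mirror the ModelId, Text, and Description fields of worker.GenTextToSpeechJSONRequestBody) are illustrative assumptions, not part of the patch:

	package main

	import (
		"bytes"
		"encoding/json"
		"fmt"
		"io"
		"net/http"
	)

	func main() {
		// Assumed field names; the model is the default from server/ai_process.go.
		body, _ := json.Marshal(map[string]string{
			"model_id":    "parler-tts/parler-tts-large-v1",
			"text":        "let me tell you a story",
			"description": "a young adult",
		})

		// Assumed gateway address.
		resp, err := http.Post("http://localhost:8935/text-to-speech", "application/json", bytes.NewReader(body))
		if err != nil {
			panic(err)
		}
		defer resp.Body.Close()

		// On success the gateway returns AudioResponse JSON whose audio URL
		// points at the saved .wav result.
		out, _ := io.ReadAll(resp.Body)
		fmt.Println(string(out))
	}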