diff --git a/server/ai_http.go b/server/ai_http.go index ac9d59050..c26e42c3a 100644 --- a/server/ai_http.go +++ b/server/ai_http.go @@ -786,7 +786,7 @@ func parseMultiPartResult(body io.Reader, boundary string, pipeline string) core break } results = parsedResp - case "audio-to-text", "segment-anything-2", "llm": + case "audio-to-text", "segment-anything-2", "llm", "image-to-text": err := json.Unmarshal(body, &results) if err != nil { glog.Error("Error getting results json:", err) diff --git a/server/ai_worker.go b/server/ai_worker.go index 28c1ec5fd..97b2cda05 100644 --- a/server/ai_worker.go +++ b/server/ai_worker.go @@ -285,6 +285,23 @@ func runAIJob(n *core.LivepeerNode, orchAddr string, httpc *http.Client, notify return n.LLM(ctx, req) } reqOk = true + case "image-to-text": + var req worker.GenImageToTextMultipartRequestBody + err = json.Unmarshal(reqData.Request, &req) + if err != nil || req.ModelId == nil { + break + } + input, err = core.DownloadData(ctx, reqData.InputUrl) + if err != nil { + break + } + modelID = *req.ModelId + resultType = "application/json" + req.Image.InitFromBytes(input, "image") + processFn = func(ctx context.Context) (interface{}, error) { + return n.ImageToText(ctx, req) + } + reqOk = true case "text-to-speech": var req worker.GenTextToSpeechJSONRequestBody err = json.Unmarshal(reqData.Request, &req) diff --git a/server/ai_worker_test.go b/server/ai_worker_test.go index 741d96583..4ecbd8768 100644 --- a/server/ai_worker_test.go +++ b/server/ai_worker_test.go @@ -204,6 +204,13 @@ func TestRunAIJob(t *testing.T) { expectedErr: "", expectedOutputs: 1, }, + { + name: "ImageToText_Success", + notify: createAIJob(8, "image-to-text", modelId, parsedURL.String()+"/image.png"), + pipeline: "image-to-text", + expectedErr: "", + expectedOutputs: 1, + }, { name: "TextToSpeech_Success", notify: createAIJob(9, "text-to-speech", modelId, ""), @@ -319,6 +326,15 @@ func TestRunAIJob(t *testing.T) { assert.Equal(len(results.Files), 0) expectedResp, 
_ := wkr.LLM(context.Background(), worker.GenLLMFormdataRequestBody{}) assert.Equal(expectedResp, &jsonRes) + case "image-to-text": + res, _ := json.Marshal(results.Results) + var jsonRes worker.ImageToTextResponse + json.Unmarshal(res, &jsonRes) + + assert.Equal("8", headers.Get("TaskId")) + assert.Equal(len(results.Files), 0) + expectedResp, _ := wkr.ImageToText(context.Background(), worker.GenImageToTextMultipartRequestBody{}) + assert.Equal(expectedResp, &jsonRes) case "text-to-speech": audResp, ok := results.Results.(worker.AudioResponse) assert.True(ok) @@ -357,6 +373,9 @@ func createAIJob(taskId int64, pipeline, modelId, inputUrl string) *net.NotifyAI req = worker.GenSegmentAnything2MultipartRequestBody{ModelId: &modelId, Image: inputFile} case "llm": req = worker.GenLLMFormdataRequestBody{Prompt: "tell me a story", ModelId: &modelId} + case "image-to-text": + inputFile.InitFromBytes(nil, inputUrl) + req = worker.GenImageToTextMultipartRequestBody{Prompt: "test prompt", ModelId: &modelId, Image: inputFile} case "text-to-speech": desc := "a young adult" text := "let me tell you a story"