Skip to content

Commit

Permalink
Fix after merge
Browse files Browse the repository at this point in the history
  • Loading branch information
mjh1 committed Oct 22, 2024
1 parent 5307c43 commit 7515487
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 77 deletions.
14 changes: 0 additions & 14 deletions cmd/livepeer/starter/starter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1232,20 +1232,6 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
capabilityConstraints[pipelineCap] = &core.CapabilityConstraints{
Models: make(map[string]*core.ModelConstraint),
}
case "image-to-text":
_, ok := capabilityConstraints[core.Capability_ImageToText]
if !ok {
aiCaps = append(aiCaps, core.Capability_ImageToText)
capabilityConstraints[core.Capability_ImageToText] = &core.CapabilityConstraints{
Models: make(map[string]*core.ModelConstraint),
}
}

capabilityConstraints[core.Capability_ImageToText].Models[config.ModelID] = modelConstraint

if *cfg.Network != "offchain" {
n.SetBasePriceForCap("default", core.Capability_ImageToText, config.ModelID, autoPrice)
}
}
model, exists := capabilityConstraints[pipelineCap].Models[config.ModelID]
if !exists {
Expand Down
40 changes: 40 additions & 0 deletions core/ai_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,42 @@ func (orch *orchestrator) LLM(ctx context.Context, requestID string, req worker.
return res.Results, nil
}

// ImageToText runs an image-to-text job for the given request.
//
// When the node has a local AIWorker the request is handed to it directly and
// its response is returned as-is. Otherwise the input image is saved via
// SaveAIRequestInput, the job is dispatched through the node's AIWorkerManager,
// and the worker's results are passed through saveRemoteAIWorkerResults before
// the Results field is returned.
//
// Any error from reading the image, saving the input, processing the job, or
// saving the results is returned unchanged with a nil result.
func (orch *orchestrator) ImageToText(ctx context.Context, requestID string, req worker.GenImageToTextMultipartRequestBody) (interface{}, error) {
	const pipeline = "image-to-text"

	// Combined orchestrator/AI worker: the local worker processes the job.
	// The response is text sent back to the gateway, so there is no file to save.
	if orch.node.AIWorker != nil {
		return orch.node.ImageToText(ctx, req)
	}

	// Remote AI worker processes the job: save the image first so the worker
	// can be handed an input URL.
	imgData, err := req.Image.Bytes()
	if err != nil {
		return nil, err
	}
	inputUrl, err := orch.SaveAIRequestInput(ctx, requestID, imgData)
	if err != nil {
		return nil, err
	}
	// Re-initialize the image with no data and no filename.
	// NOTE(review): presumably so the image bytes are not carried inline in
	// the job request now that the input URL exists — confirm.
	req.Image.InitFromBytes(nil, "")

	jobData := AIJobRequestData{Request: req, InputUrl: inputUrl}
	workerRes, err := orch.node.AIWorkerManager.Process(ctx, requestID, pipeline, *req.ModelId, inputUrl, jobData)
	if err != nil {
		return nil, err
	}

	savedRes, err := orch.node.saveRemoteAIWorkerResults(ctx, workerRes, requestID)
	if err != nil {
		clog.Errorf(ctx, "Error saving remote ai result err=%q", err)
		if monitor.Enabled {
			monitor.AIResultSaveError(ctx, pipeline, *req.ModelId, string(monitor.SegmentUploadErrorUnknown))
		}
		return nil, err
	}

	return savedRes.Results, nil
}

// only used for sending work to remote AI worker
func (orch *orchestrator) SaveAIRequestInput(ctx context.Context, requestID string, fileData []byte) (string, error) {
node := orch.node
Expand Down Expand Up @@ -855,6 +891,10 @@ func (n *LivepeerNode) Upscale(ctx context.Context, req worker.GenUpscaleMultipa
// AudioToText forwards the audio-to-text request to the node's AIWorker and
// returns the worker's response and error unchanged.
//
// NOTE(review): n.AIWorker is used without a nil check here; presumably callers
// verify a local AI worker is configured before calling — confirm.
func (n *LivepeerNode) AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
return n.AIWorker.AudioToText(ctx, req)
}

// ImageToText forwards the image-to-text request to the node's AIWorker and
// returns the worker's response and error unchanged.
//
// NOTE(review): n.AIWorker is used without a nil check here; presumably callers
// verify a local AI worker is configured before calling — confirm.
func (n *LivepeerNode) ImageToText(ctx context.Context, req worker.GenImageToTextMultipartRequestBody) (*worker.ImageToTextResponse, error) {
return n.AIWorker.ImageToText(ctx, req)
}
func (n *LivepeerNode) ImageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
// We might support generating more than one video in the future (i.e. multiple input images/prompts)
numVideos := 1
Expand Down
4 changes: 0 additions & 4 deletions core/orchestrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,6 @@ func (orch *orchestrator) TranscoderResults(tcID int64, res *RemoteTranscoderRes
orch.node.TranscoderManager.transcoderResults(tcID, res)
}

// ImageToText delegates the image-to-text request to the orchestrator's node
// and returns the node's response and error unchanged.
func (orch *orchestrator) ImageToText(ctx context.Context, req worker.GenImageToTextMultipartRequestBody) (*worker.ImageToTextResponse, error) {
return orch.node.ImageToText(ctx, req)
}

func (orch *orchestrator) ProcessPayment(ctx context.Context, payment net.Payment, manifestID ManifestID) error {
if orch.node == nil || orch.node.Recipient == nil {
return nil
Expand Down
4 changes: 1 addition & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ require (
github.com/golang/protobuf v1.5.4
github.com/jaypipes/ghw v0.10.0
github.com/jaypipes/pcidb v1.0.0
github.com/livepeer/ai-worker v0.7.0
github.com/livepeer/ai-worker v0.8.0
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b
github.com/livepeer/livepeer-data v0.7.5-0.20231004073737-06f1f383fb18
github.com/livepeer/lpms v0.0.0-20240909171057-fe5aff1fa6a2
Expand Down Expand Up @@ -251,5 +251,3 @@ require (
lukechampine.com/blake3 v1.2.1 // indirect
rsc.io/tmplfunc v0.0.3 // indirect
)

replace github.com/livepeer/ai-worker => /Users/max/go/src/github.com/livepeer/ai-worker
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,8 @@ github.com/libp2p/go-openssl v0.1.0 h1:LBkKEcUv6vtZIQLVTegAil8jbNpJErQ9AnT+bWV+O
github.com/libp2p/go-openssl v0.1.0/go.mod h1:OiOxwPpL3n4xlenjx2h7AwSGaFSC/KZvf6gNdOBQMtc=
github.com/livepeer/ai-worker v0.7.0 h1:9z5Uz9WvKyQTXiurWim1ewDcVPLzz7EYZEfm2qtLAaw=
github.com/livepeer/ai-worker v0.7.0/go.mod h1:91lMzkzVuwR9kZ0EzXwf+7yVhLaNVmYAfmBtn7t3cQA=
github.com/livepeer/ai-worker v0.8.0 h1:z4gczVYl47hFkzV1FHVi5aFaqmA27LyUIEmqAEwhR9U=
github.com/livepeer/ai-worker v0.8.0/go.mod h1:91lMzkzVuwR9kZ0EzXwf+7yVhLaNVmYAfmBtn7t3cQA=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b h1:VQcnrqtCA2UROp7q8ljkh2XA/u0KRgVv0S1xoUvOweE=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b/go.mod h1:hwJ5DKhl+pTanFWl+EUpw1H7ukPO/H+MFpgA7jjshzw=
github.com/livepeer/joy4 v0.1.2-0.20191121080656-b2fea45cbded h1:ZQlvR5RB4nfT+cOQee+WqmaDOgGtP2oDMhcVvR4L0yA=
Expand Down
2 changes: 1 addition & 1 deletion server/ai_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
cap = core.Capability_ImageToText
modelID = *v.ModelId
submitFn = func(ctx context.Context) (interface{}, error) {
return orch.ImageToText(ctx, v)
return orch.ImageToText(ctx, requestID, v)
}

imageRdr, err := v.Image.Reader()
Expand Down
55 changes: 1 addition & 54 deletions server/ai_mediaserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func startAIMediaServer(ls *LivepeerServer) error {
ls.HTTPMux.Handle("/audio-to-text", oapiReqValidator(handle(ls, multipartDecoder[worker.GenAudioToTextMultipartRequestBody], processAudioToText)))
ls.HTTPMux.Handle("/llm", oapiReqValidator(ls.LLM()))
ls.HTTPMux.Handle("/segment-anything-2", oapiReqValidator(handle(ls, multipartDecoder[worker.GenSegmentAnything2MultipartRequestBody], processSegmentAnything2)))
ls.HTTPMux.Handle("/image-to-text", oapiReqValidator(ls.ImageToText()))
ls.HTTPMux.Handle("/image-to-text", oapiReqValidator(handle(ls, multipartDecoder[worker.GenImageToTextMultipartRequestBody], processImageToText)))

return nil
}
Expand Down Expand Up @@ -361,56 +361,3 @@ func (ls *LivepeerServer) ImageToVideoResult() http.Handler {
_ = json.NewEncoder(w).Encode(resp)
})
}

// ImageToText returns the HTTP handler for image-to-text requests.
//
// The handler tags the request context with the client IP and a freshly
// generated request ID, binds the multipart body into a
// worker.GenImageToTextMultipartRequestBody, and passes it to
// processImageToText. On success the result is written as JSON with status
// 200. Failures are reported through respondJsonError:
//   - 400 when the multipart body cannot be read or bound, or when processing
//     fails with a *BadRequestError;
//   - 503 when processing fails with a *ServiceUnavailableError;
//   - 500 for any other processing error.
func (ls *LivepeerServer) ImageToText() http.Handler {
	handler := func(w http.ResponseWriter, r *http.Request) {
		clientAddr := getRemoteAddr(r)
		ctx := clog.AddVal(r.Context(), clog.ClientIP, clientAddr)
		requestID := string(core.RandomManifestID())
		ctx = clog.AddVal(ctx, "request_id", requestID)

		// Reading and binding the multipart body share the same failure
		// response, so both are funneled through a single error check.
		var req worker.GenImageToTextMultipartRequestBody
		multiRdr, err := r.MultipartReader()
		if err == nil {
			err = runtime.BindMultipart(&req, *multiRdr)
		}
		if err != nil {
			respondJsonError(ctx, w, err, http.StatusBadRequest)
			return
		}

		clog.V(common.VERBOSE).Infof(ctx, "Received ImageToText request imageSize=%v model_id=%v", req.Image.FileSize(), *req.ModelId)

		params := aiRequestParams{
			node:        ls.LivepeerNode,
			os:          drivers.NodeStorage.NewSession(requestID),
			sessManager: ls.AISessionManager,
		}

		startedAt := time.Now()
		resp, err := processImageToText(ctx, params, req)
		if err != nil {
			// Pick the status code from the error type; service-unavailable
			// is checked before bad-request, anything else is a 500.
			var (
				unavailableErr *ServiceUnavailableError
				badReqErr      *BadRequestError
			)
			status := http.StatusInternalServerError
			switch {
			case errors.As(err, &unavailableErr):
				status = http.StatusServiceUnavailable
			case errors.As(err, &badReqErr):
				status = http.StatusBadRequest
			}
			respondJsonError(ctx, w, err, status)
			return
		}

		elapsed := time.Since(startedAt)
		clog.V(common.VERBOSE).Infof(ctx, "Processed ImageToText request model_id=%v took=%v", *req.ModelId, elapsed)

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		// The status line is already written, so an encode failure cannot be
		// reported to the client; the error is deliberately ignored.
		_ = json.NewEncoder(w).Encode(resp)
	}
	return http.HandlerFunc(handler)
}
2 changes: 1 addition & 1 deletion server/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ type Orchestrator interface {
AudioToText(ctx context.Context, requestID string, req worker.GenAudioToTextMultipartRequestBody) (interface{}, error)
LLM(ctx context.Context, requestID string, req worker.GenLLMFormdataRequestBody) (interface{}, error)
SegmentAnything2(ctx context.Context, requestID string, req worker.GenSegmentAnything2MultipartRequestBody) (interface{}, error)
ImageToText(ctx context.Context, req worker.GenImageToTextMultipartRequestBody) (*worker.ImageToTextResponse, error)
ImageToText(ctx context.Context, requestID string, req worker.GenImageToTextMultipartRequestBody) (interface{}, error)
}

// Balance describes methods for a session's balance maintenance
Expand Down

0 comments on commit 7515487

Please sign in to comment.