Skip to content

Commit

Permalink
Merge branch 'ai-video' into new-payment-mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
leszko authored Nov 5, 2024
2 parents 90f04dc + 836d006 commit 6f36d2d
Show file tree
Hide file tree
Showing 15 changed files with 403 additions and 46 deletions.
1 change: 1 addition & 0 deletions cmd/livepeer/starter/starter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1239,6 +1239,7 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
Models: make(map[string]*core.ModelConstraint),
}
}

model, exists := capabilityConstraints[pipelineCap].Models[config.ModelID]
if !exists {
capabilityConstraints[pipelineCap].Models[config.ModelID] = modelConstraint
Expand Down
1 change: 1 addition & 0 deletions common/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ var (
"video/mp2t": ".ts",
"video/mp4": ".mp4",
"image/png": ".png",
"audio/wav": ".wav",
}
)

Expand Down
1 change: 1 addition & 0 deletions core/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type AI interface {
LLM(context.Context, worker.GenLLMFormdataRequestBody) (interface{}, error)
SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
ImageToText(context.Context, worker.GenImageToTextMultipartRequestBody) (*worker.ImageToTextResponse, error)
TextToSpeech(context.Context, worker.GenTextToSpeechJSONRequestBody) (*worker.AudioResponse, error)
Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
Stop(context.Context) error
HasCapacity(pipeline, modelID string) bool
Expand Down
4 changes: 4 additions & 0 deletions core/ai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,10 @@ func (a *stubAIWorker) ImageToText(ctx context.Context, req worker.GenImageToTex
return &worker.ImageToTextResponse{Text: "Transcribed text"}, nil
}

// TextToSpeech returns a canned audio URL so tests can exercise the
// text-to-speech pipeline without running a real AI worker.
func (a *stubAIWorker) TextToSpeech(ctx context.Context, req worker.GenTextToSpeechJSONRequestBody) (*worker.AudioResponse, error) {
	resp := worker.AudioResponse{Audio: worker.MediaURL{Url: "http://example.com/audio.wav"}}
	return &resp, nil
}

func (a *stubAIWorker) Warm(ctx context.Context, arg1, arg2 string, endpoint worker.RunnerEndpoint, flags worker.OptimizationFlags) error {
return nil
}
Expand Down
116 changes: 90 additions & 26 deletions core/ai_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,68 +422,95 @@ func (n *LivepeerNode) saveLocalAIWorkerResults(ctx context.Context, results int
ext, _ := common.MimeTypeToExtension(contentType)
fileName := string(RandomManifestID()) + ext

imgRes, ok := results.(worker.ImageResponse)
if !ok {
// worker.TextResponse is JSON, no file save needed
return results, nil
}
storage, exists := n.StorageConfigs[requestID]
if !exists {
return nil, errors.New("no storage available for request")
}

var buf bytes.Buffer
for i, image := range imgRes.Images {
buf.Reset()
err := worker.ReadImageB64DataUrl(image.Url, &buf)
if err != nil {
// try to load local file (image to video returns local file)
f, err := os.ReadFile(image.Url)
switch resp := results.(type) {
case worker.ImageResponse:
for i, image := range resp.Images {
buf.Reset()
err := worker.ReadImageB64DataUrl(image.Url, &buf)
if err != nil {
// try to load local file (image to video returns local file)
f, err := os.ReadFile(image.Url)
if err != nil {
return nil, err
}
buf = *bytes.NewBuffer(f)
}

osUrl, err := storage.OS.SaveData(ctx, fileName, bytes.NewBuffer(buf.Bytes()), nil, 0)
if err != nil {
return nil, err
}
buf = *bytes.NewBuffer(f)

resp.Images[i].Url = osUrl
}

results = resp
case worker.AudioResponse:
err := worker.ReadAudioB64DataUrl(resp.Audio.Url, &buf)
if err != nil {
return nil, err
}

osUrl, err := storage.OS.SaveData(ctx, fileName, bytes.NewBuffer(buf.Bytes()), nil, 0)
if err != nil {
return nil, err
}
resp.Audio.Url = osUrl

imgRes.Images[i].Url = osUrl
results = resp
}

return imgRes, nil
//no file response to save, response is text
return results, nil
}

func (n *LivepeerNode) saveRemoteAIWorkerResults(ctx context.Context, results *RemoteAIWorkerResult, requestID string) (*RemoteAIWorkerResult, error) {
if drivers.NodeStorage == nil {
return nil, fmt.Errorf("Missing local storage")
}

// save the file data to node and provide url for download
storage, exists := n.StorageConfigs[requestID]
if !exists {
return nil, errors.New("no storage available for request")
}
	// worker.ImageResponse used by ***-to-image and image-to-video requires saving binary data for download
	// worker.AudioResponse used by text-to-speech also requires saving binary data for download
// other pipelines do not require saving data since they are text responses
imgResp, isImg := results.Results.(worker.ImageResponse)
if isImg {
for idx := range imgResp.Images {
fileName := imgResp.Images[idx].Url
// save the file data to node and provide url for download
storage, exists := n.StorageConfigs[requestID]
if !exists {
return nil, errors.New("no storage available for request")
}
switch resp := results.Results.(type) {
case worker.ImageResponse:
for idx := range resp.Images {
fileName := resp.Images[idx].Url
osUrl, err := storage.OS.SaveData(ctx, fileName, bytes.NewReader(results.Files[fileName]), nil, 0)
if err != nil {
return nil, err
}

imgResp.Images[idx].Url = osUrl
resp.Images[idx].Url = osUrl
delete(results.Files, fileName)
}

// update results for url updates
results.Results = imgResp
results.Results = resp
case worker.AudioResponse:
fileName := resp.Audio.Url
osUrl, err := storage.OS.SaveData(ctx, fileName, bytes.NewReader(results.Files[fileName]), nil, 0)
if err != nil {
return nil, err
}

resp.Audio.Url = osUrl
delete(results.Files, fileName)

results.Results = resp
}

// no file response to save, response is text
return results, nil
}

Expand Down Expand Up @@ -789,6 +816,39 @@ func (orch *orchestrator) ImageToText(ctx context.Context, requestID string, req
return res.Results, nil
}

// TextToSpeech runs a text-to-speech job either on the local AI worker
// (combined orchestrator/ai-worker node) or on a remote AI worker, saves
// the resulting audio for download, and returns the results.
func (orch *orchestrator) TextToSpeech(ctx context.Context, requestID string, req worker.GenTextToSpeechJSONRequestBody) (interface{}, error) {
	// A non-nil AIWorker means this node is a combined orchestrator/ai worker,
	// so the job is processed locally.
	if orch.node.AIWorker != nil {
		workerResp, err := orch.node.TextToSpeech(ctx, req)
		if err != nil {
			clog.Errorf(ctx, "Error processing with local ai worker err=%q", err)
			if monitor.Enabled {
				monitor.AIResultSaveError(ctx, "text-to-speech", *req.ModelId, string(monitor.SegmentUploadErrorUnknown))
			}
			return nil, err
		}
		return orch.node.saveLocalAIWorkerResults(ctx, *workerResp, requestID, "audio/wav")
	}

	// Otherwise hand the job off to a remote ai worker.
	res, err := orch.node.AIWorkerManager.Process(ctx, requestID, "text-to-speech", *req.ModelId, "", AIJobRequestData{Request: req})
	if err != nil {
		return nil, err
	}

	res, err = orch.node.saveRemoteAIWorkerResults(ctx, res, requestID)
	if err != nil {
		clog.Errorf(ctx, "Error saving remote ai result err=%q", err)
		if monitor.Enabled {
			monitor.AIResultSaveError(ctx, "text-to-speech", *req.ModelId, string(monitor.SegmentUploadErrorUnknown))
		}
		return nil, err
	}

	return res.Results, nil
}

// only used for sending work to remote AI worker
func (orch *orchestrator) SaveAIRequestInput(ctx context.Context, requestID string, fileData []byte) (string, error) {
node := orch.node
Expand Down Expand Up @@ -959,6 +1019,10 @@ func (n *LivepeerNode) LLM(ctx context.Context, req worker.GenLLMFormdataRequest
return n.AIWorker.LLM(ctx, req)
}

// TextToSpeech delegates a text-to-speech request to this node's attached AI worker.
func (n *LivepeerNode) TextToSpeech(ctx context.Context, req worker.GenTextToSpeechJSONRequestBody) (*worker.AudioResponse, error) {
	return n.AIWorker.TextToSpeech(ctx, req)
}

// transcodeFrames converts a series of image URLs into a video segment for the image-to-video pipeline.
func (n *LivepeerNode) transcodeFrames(ctx context.Context, sessionID string, urls []string, inProfile ffmpeg.VideoProfile, outProfile ffmpeg.VideoProfile) *TranscodeResult {
ctx = clog.AddOrchSessionID(ctx, sessionID)
Expand Down
3 changes: 3 additions & 0 deletions core/capabilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ const (
Capability_LLM Capability = 33
Capability_ImageToText Capability = 34
Capability_LiveVideoToVideo Capability = 35
Capability_TextToSpeech Capability = 36
)

var CapabilityNameLookup = map[Capability]string{
Expand Down Expand Up @@ -122,6 +123,7 @@ var CapabilityNameLookup = map[Capability]string{
Capability_LLM: "Llm",
Capability_ImageToText: "Image to text",
Capability_LiveVideoToVideo: "Live video to video",
Capability_TextToSpeech: "Text to speech",
}

var CapabilityTestLookup = map[Capability]CapabilityTest{
Expand Down Expand Up @@ -214,6 +216,7 @@ func OptionalCapabilities() []Capability {
Capability_AudioToText,
Capability_SegmentAnything2,
Capability_ImageToText,
Capability_TextToSpeech,
}
}

Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ require (
github.com/golang/protobuf v1.5.4
github.com/jaypipes/ghw v0.10.0
github.com/jaypipes/pcidb v1.0.0
github.com/livepeer/ai-worker v0.9.2
github.com/livepeer/ai-worker v0.11.0
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b
github.com/livepeer/livepeer-data v0.7.5-0.20231004073737-06f1f383fb18
github.com/livepeer/lpms v0.0.0-20240909171057-fe5aff1fa6a2
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -604,8 +604,8 @@ github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4n
github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI=
github.com/libp2p/go-openssl v0.1.0 h1:LBkKEcUv6vtZIQLVTegAil8jbNpJErQ9AnT+bWV+Ooo=
github.com/libp2p/go-openssl v0.1.0/go.mod h1:OiOxwPpL3n4xlenjx2h7AwSGaFSC/KZvf6gNdOBQMtc=
github.com/livepeer/ai-worker v0.9.2 h1:kgXb6sjfi93pJxxsAtWyAGo53/+gHsf7JMRVApor+zU=
github.com/livepeer/ai-worker v0.9.2/go.mod h1:/Deme7XXRP4BiYXt/j694Ygw+dh8rWJdikJsKY64sjE=
github.com/livepeer/ai-worker v0.11.0 h1:prbRRBgCIrECUuZFWuyN6z3QZLfygYqBKYMleT+I7o4=
github.com/livepeer/ai-worker v0.11.0/go.mod h1:GjQuPmz69UO53WVtqzB9Ygok5MmKCGNuobbfMXH7zgw=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b h1:VQcnrqtCA2UROp7q8ljkh2XA/u0KRgVv0S1xoUvOweE=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b/go.mod h1:hwJ5DKhl+pTanFWl+EUpw1H7ukPO/H+MFpgA7jjshzw=
github.com/livepeer/joy4 v0.1.2-0.20191121080656-b2fea45cbded h1:ZQlvR5RB4nfT+cOQee+WqmaDOgGtP2oDMhcVvR4L0yA=
Expand Down
44 changes: 43 additions & 1 deletion server/ai_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"strconv"
"strings"
"time"
"unicode/utf8"

"github.com/getkin/kin-openapi/openapi3filter"
"github.com/golang/glog"
Expand Down Expand Up @@ -57,6 +58,7 @@ func startAIServer(lp lphttp) error {
lp.transRPC.Handle("/segment-anything-2", oapiReqValidator(lp.SegmentAnything2()))
lp.transRPC.Handle("/image-to-text", oapiReqValidator(lp.ImageToText()))
lp.transRPC.Handle("/live-video-to-video", oapiReqValidator(lp.StartLiveVideoToVideo()))
lp.transRPC.Handle("/text-to-speech", oapiReqValidator(lp.TextToSpeech()))
// Additionally, there is the '/aiResults' endpoint registered in server/rpc.go

return nil
Expand Down Expand Up @@ -269,6 +271,23 @@ func (h *lphttp) StartLiveVideoToVideo() http.Handler {
})
}

// TextToSpeech returns the HTTP handler for the orchestrator's
// /text-to-speech endpoint: it decodes the JSON request body and hands
// the job to the shared AI request pipeline.
func (h *lphttp) TextToSpeech() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Tag the request context with the caller's IP for log correlation.
		ctx := clog.AddVal(r.Context(), clog.ClientIP, getRemoteAddr(r))

		var req worker.GenTextToSpeechJSONRequestBody
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			respondWithError(w, err.Error(), http.StatusBadRequest)
			return
		}

		handleAIRequest(ctx, w, r, h.orchestrator, req)
	})
}

func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, orch Orchestrator, req interface{}) {
payment, err := getPayment(r.Header.Get(paymentHeader))
if err != nil {
Expand Down Expand Up @@ -448,6 +467,18 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
return
}
outPixels = int64(config.Height) * int64(config.Width)
case worker.GenTextToSpeechJSONRequestBody:
pipeline = "text-to-speech"
cap = core.Capability_TextToSpeech
modelID = *v.ModelId

submitFn = func(ctx context.Context) (interface{}, error) {
return orch.TextToSpeech(ctx, requestID, v)
}

// TTS pricing is typically in characters, including punctuation.
words := utf8.RuneCountInString(*v.Text)
outPixels = int64(1000 * words)
default:
respondWithError(w, "Unknown request type", http.StatusBadRequest)
return
Expand Down Expand Up @@ -551,6 +582,8 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
latencyScore = CalculateSegmentAnything2LatencyScore(took, outPixels)
case worker.GenImageToTextMultipartRequestBody:
latencyScore = CalculateImageToTextLatencyScore(took, outPixels)
case worker.GenTextToSpeechJSONRequestBody:
latencyScore = CalculateTextToSpeechLatencyScore(took, outPixels)
}

var pricePerAIUnit float64
Expand Down Expand Up @@ -753,13 +786,22 @@ func parseMultiPartResult(body io.Reader, boundary string, pipeline string) core
break
}
results = parsedResp
case "audio-to-text", "segment-anything-2", "llm":
case "audio-to-text", "segment-anything-2", "llm", "image-to-text":
err := json.Unmarshal(body, &results)
if err != nil {
glog.Error("Error getting results json:", err)
wkrResult.Err = err
break
}
case "text-to-speech":
var parsedResp worker.AudioResponse
err := json.Unmarshal(body, &parsedResp)
if err != nil {
glog.Error("Error getting results json:", err)
wkrResult.Err = err
break
}
results = parsedResp
}

wkrResult.Results = results
Expand Down
1 change: 1 addition & 0 deletions server/ai_mediaserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ func startAIMediaServer(ls *LivepeerServer) error {
ls.HTTPMux.Handle("/llm", oapiReqValidator(ls.LLM()))
ls.HTTPMux.Handle("/segment-anything-2", oapiReqValidator(handle(ls, multipartDecoder[worker.GenSegmentAnything2MultipartRequestBody], processSegmentAnything2)))
ls.HTTPMux.Handle("/image-to-text", oapiReqValidator(handle(ls, multipartDecoder[worker.GenImageToTextMultipartRequestBody], processImageToText)))
ls.HTTPMux.Handle("/text-to-speech", oapiReqValidator(handle(ls, jsonDecoder[worker.GenTextToSpeechJSONRequestBody], processTextToSpeech)))

// This is called by the media server when the stream is ready
ls.HTTPMux.Handle("/live/video-to-video/start", ls.StartLiveVideo())
Expand Down
Loading

0 comments on commit 6f36d2d

Please sign in to comment.