From db926896bd7a0d51f8d94fc7c5a78dfbf45b0dda Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Fri, 5 Jan 2024 12:04:46 -0500 Subject: [PATCH] Revert "[Refactor]: Core/API Split" (#1550) Revert "[Refactor]: Core/API Split (#1506)" This reverts commit ab7b4d5ee9448e533a342bd1771393acd2967191. --- .gitignore | 4 +- Dockerfile | 2 +- api/api.go | 302 ++++++ {core/http => api}/api_test.go | 86 +- {core/http => api}/apt_suite_test.go | 2 +- {core => api}/backend/embeddings.go | 60 +- api/backend/image.go | 61 ++ api/backend/llm.go | 167 ++++ {core => api}/backend/options.go | 12 +- api/backend/transcript.go | 39 + {core => api}/backend/tts.go | 16 +- {pkg/schema => api/config}/config.go | 273 +++--- {pkg/schema => api/config}/config_test.go | 17 +- {pkg/schema => api/config}/prediction.go | 2 +- api/localai/backend_monitor.go | 162 ++++ api/localai/gallery.go | 326 +++++++ api/localai/localai.go | 32 + api/openai/chat.go | 399 ++++++++ api/openai/completion.go | 199 ++++ api/openai/edit.go | 94 ++ api/openai/embeddings.go | 78 ++ api/openai/image.go | 239 +++++ api/openai/inference.go | 55 ++ {core/http/endpoints => api}/openai/list.go | 20 +- api/openai/request.go | 336 +++++++ api/openai/transcription.go | 71 ++ .../options/options.go | 94 +- {pkg => api}/schema/openai.go | 4 +- {pkg => api}/schema/whisper.go | 8 +- backend/go/transcribe/transcript.go | 8 +- backend/go/transcribe/whisper.go | 4 +- config/.keep | 0 core/backend/image.go | 210 ----- core/backend/llm.go | 861 ------------------ core/backend/transcription.go | 52 -- core/http/api.go | 169 ---- .../http/endpoints/localai/backend_monitor.go | 34 - core/http/endpoints/localai/gallery.go | 148 --- core/http/endpoints/localai/metrics.go | 42 - core/http/endpoints/localai/tts.go | 25 - core/http/endpoints/openai/chat.go | 97 -- core/http/endpoints/openai/completion.go | 91 -- core/http/endpoints/openai/edit.go | 34 - core/http/endpoints/openai/embeddings.go | 35 - core/http/endpoints/openai/image.go | 48 - core/http/endpoints/openai/request.go | 57 -- core/http/endpoints/openai/transcription.go | 49 - core/mqtt/manager.go | 24 - core/services/backend_monitor.go | 138 --- core/services/config.go | 157 ---- core/services/gallery.go | 160 ---- core/services/metrics.go | 29 - core/startup/config_file_watcher.go | 100 -- core/startup/startup.go | 93 -- docs/content/advanced/development.md | 47 - docs/content/features/text-to-audio.md | 7 +- go.mod | 4 +- go.sum | 11 +- main.go | 120 +-- metrics/metrics.go | 83 ++ pkg/gallery/gallery.go | 6 +- pkg/gallery/models.go | 69 +- pkg/gallery/models_test.go | 8 +- pkg/gallery/op.go | 18 - pkg/gallery/request_test.go | 2 +- pkg/grpc/base/base.go | 7 +- pkg/grpc/client.go | 8 +- pkg/grpc/interface.go | 4 +- pkg/grpc/proto/backend.pb.go | 4 +- pkg/grpc/proto/backend_grpc.pb.go | 55 +- pkg/model/initializers.go | 4 +- pkg/model/loader.go | 2 +- pkg/model/options.go | 49 +- pkg/schema/localai.go | 39 - pkg/utils/file.go | 81 -- pkg/utils/uri.go | 65 +- tests/integration/reflect_test.go | 8 +- 77 files changed, 3101 insertions(+), 3425 deletions(-) create mode 100644 api/api.go rename {core/http => api}/api_test.go (93%) rename {core/http => api}/apt_suite_test.go (90%) rename {core => api}/backend/embeddings.go (50%) create mode 100644 api/backend/image.go create mode 100644 api/backend/llm.go rename {core => api}/backend/options.go (90%) create mode 100644 api/backend/transcript.go rename {core => api}/backend/tts.go (76%) rename {pkg/schema => api/config}/config.go (60%) rename {pkg/schema => api/config}/config_test.go (74%) rename {pkg/schema => api/config}/prediction.go (99%) create mode 100644 api/localai/backend_monitor.go create mode 100644 api/localai/gallery.go create mode 100644 api/localai/localai.go create mode 100644 api/openai/chat.go create mode 100644 api/openai/completion.go create mode 100644 api/openai/edit.go create mode 100644 api/openai/embeddings.go create mode 100644 api/openai/image.go create mode 100644 api/openai/inference.go rename {core/http/endpoints => api}/openai/list.go (70%) create mode 100644 api/openai/request.go create mode 100644 api/openai/transcription.go rename pkg/schema/startup_options.go => api/options/options.go (68%) rename {pkg => api}/schema/openai.go (97%) rename {pkg => api}/schema/whisper.go (60%) delete mode 100644 config/.keep delete mode 100644 core/backend/image.go delete mode 100644 core/backend/llm.go delete mode 100644 core/backend/transcription.go delete mode 100644 core/http/api.go delete mode 100644 core/http/endpoints/localai/backend_monitor.go delete mode 100644 core/http/endpoints/localai/gallery.go delete mode 100644 core/http/endpoints/localai/metrics.go delete mode 100644 core/http/endpoints/localai/tts.go delete mode 100644 core/http/endpoints/openai/chat.go delete mode 100644 core/http/endpoints/openai/completion.go delete mode 100644 core/http/endpoints/openai/edit.go delete mode 100644 core/http/endpoints/openai/embeddings.go delete mode 100644 core/http/endpoints/openai/image.go delete mode 100644 core/http/endpoints/openai/request.go delete mode 100644 core/http/endpoints/openai/transcription.go delete mode 100644 core/mqtt/manager.go delete mode 100644 core/services/backend_monitor.go delete mode 100644 core/services/config.go delete mode 100644 core/services/gallery.go delete mode 100644 core/services/metrics.go delete mode 100644 core/startup/config_file_watcher.go delete mode 100644 core/startup/startup.go create mode 100644 metrics/metrics.go delete mode 100644 pkg/gallery/op.go delete mode 100644 pkg/schema/localai.go delete mode 100644 pkg/utils/file.go diff --git a/.gitignore b/.gitignore index a8e265a578c3..df00829cf890 100644 --- a/.gitignore +++ b/.gitignore @@ -19,8 +19,8 @@ LocalAI local-ai # prevent above rules from omitting the helm chart !charts/* -# prevent above rules from omitting the core/**/localai folder -!core/**/localai +# prevent above rules from omitting the api/localai folder +!api/localai # Ignore models models/* diff --git a/Dockerfile b/Dockerfile index 3e9b5f6be530..7f7ee8177e43 100644 --- a/Dockerfile +++ b/Dockerfile @@ -88,7 +88,7 @@ ENV NVIDIA_VISIBLE_DEVICES=all WORKDIR /build COPY . . -COPY .git/ .git/ +COPY .git . RUN make prepare # stablediffusion does not tolerate a newer version of abseil, build it first diff --git a/api/api.go b/api/api.go new file mode 100644 index 000000000000..365346bdbb41 --- /dev/null +++ b/api/api.go @@ -0,0 +1,302 @@ +package api + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/localai" + "github.com/go-skynet/LocalAI/api/openai" + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/api/schema" + "github.com/go-skynet/LocalAI/internal" + "github.com/go-skynet/LocalAI/metrics" + "github.com/go-skynet/LocalAI/pkg/assets" + "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/utils" + + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/cors" + "github.com/gofiber/fiber/v2/middleware/logger" + "github.com/gofiber/fiber/v2/middleware/recover" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" +) + +func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader, error) { + options := options.NewOptions(opts...) + + zerolog.SetGlobalLevel(zerolog.InfoLevel) + if options.Debug { + zerolog.SetGlobalLevel(zerolog.DebugLevel) + } + + log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath) + log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion()) + + modelPath := options.Loader.ModelPath + if len(options.ModelsURL) > 0 { + for _, url := range options.ModelsURL { + if utils.LooksLikeURL(url) { + // md5 of model name + md5Name := utils.MD5(url) + + // check if file exists + if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) { + err := utils.DownloadFile(url, filepath.Join(modelPath, md5Name)+".yaml", "", func(fileName, current, total string, percent float64) { + utils.DisplayDownloadFunction(fileName, current, total, percent) + }) + if err != nil { + log.Error().Msgf("error loading model: %s", err.Error()) + } + } + } + } + } + + cl := config.NewConfigLoader() + if err := cl.LoadConfigs(options.Loader.ModelPath); err != nil { + log.Error().Msgf("error loading config files: %s", err.Error()) + } + + if options.ConfigFile != "" { + if err := cl.LoadConfigFile(options.ConfigFile); err != nil { + log.Error().Msgf("error loading config file: %s", err.Error()) + } + } + + if err := cl.Preload(options.Loader.ModelPath); err != nil { + log.Error().Msgf("error downloading models: %s", err.Error()) + } + + if options.PreloadJSONModels != "" { + if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil { + return nil, nil, err + } + } + + if options.PreloadModelsFromPath != "" { + if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil { + return nil, nil, err + } + } + + if options.Debug { + for _, v := range cl.ListConfigs() { + cfg, _ := cl.GetConfig(v) + log.Debug().Msgf("Model: %s (config: %+v)", v, cfg) + } + } + + if options.AssetsDestination != "" { + // Extract files from the embedded FS + err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination) + log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination) + if err != nil { + log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err) + } + } + + // turn off any process that was started by GRPC if the context is canceled + go func() { + <-options.Context.Done() + log.Debug().Msgf("Context canceled, shutting down") + options.Loader.StopAllGRPC() + }() + + if options.WatchDog { + wd := model.NewWatchDog( + options.Loader, + options.WatchDogBusyTimeout, + options.WatchDogIdleTimeout, + options.WatchDogBusy, + options.WatchDogIdle) + options.Loader.SetWatchDog(wd) + go wd.Run() + go func() { + <-options.Context.Done() + log.Debug().Msgf("Context canceled, shutting down") + wd.Shutdown() + }() + } + + return options, cl, nil +} + +func App(opts ...options.AppOption) (*fiber.App, error) { + + options, cl, err := Startup(opts...) + if err != nil { + return nil, fmt.Errorf("failed basic startup tasks with error %s", err.Error()) + } + + // Return errors as JSON responses + app := fiber.New(fiber.Config{ + BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB + DisableStartupMessage: options.DisableMessage, + // Override default error handler + ErrorHandler: func(ctx *fiber.Ctx, err error) error { + // Status code defaults to 500 + code := fiber.StatusInternalServerError + + // Retrieve the custom status code if it's a *fiber.Error + var e *fiber.Error + if errors.As(err, &e) { + code = e.Code + } + + // Send custom error page + return ctx.Status(code).JSON( + schema.ErrorResponse{ + Error: &schema.APIError{Message: err.Error(), Code: code}, + }, + ) + }, + }) + + if options.Debug { + app.Use(logger.New(logger.Config{ + Format: "[${ip}]:${port} ${status} - ${method} ${path}\n", + })) + } + + // Default middleware config + app.Use(recover.New()) + if options.Metrics != nil { + app.Use(metrics.APIMiddleware(options.Metrics)) + } + + // Auth middleware checking if API key is valid. If no API key is set, no auth is required. + auth := func(c *fiber.Ctx) error { + if len(options.ApiKeys) == 0 { + return c.Next() + } + + // Check for api_keys.json file + fileContent, err := os.ReadFile("api_keys.json") + if err == nil { + // Parse JSON content from the file + var fileKeys []string + err := json.Unmarshal(fileContent, &fileKeys) + if err != nil { + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Error parsing api_keys.json"}) + } + + // Add file keys to options.ApiKeys + options.ApiKeys = append(options.ApiKeys, fileKeys...) + } + + if len(options.ApiKeys) == 0 { + return c.Next() + } + + authHeader := c.Get("Authorization") + if authHeader == "" { + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"}) + } + authHeaderParts := strings.Split(authHeader, " ") + if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" { + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"}) + } + + apiKey := authHeaderParts[1] + for _, key := range options.ApiKeys { + if apiKey == key { + return c.Next() + } + } + + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"}) + + } + + if options.CORS { + var c func(ctx *fiber.Ctx) error + if options.CORSAllowOrigins == "" { + c = cors.New() + } else { + c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins}) + } + + app.Use(c) + } + + // LocalAI API endpoints + galleryService := localai.NewGalleryService(options.Loader.ModelPath) + galleryService.Start(options.Context, cl) + + app.Get("/version", auth, func(c *fiber.Ctx) error { + return c.JSON(struct { + Version string `json:"version"` + }{Version: internal.PrintableVersion()}) + }) + + modelGalleryService := localai.CreateModelGalleryService(options.Galleries, options.Loader.ModelPath, galleryService) + app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint()) + app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint()) + app.Get("/models/galleries", auth, modelGalleryService.ListModelGalleriesEndpoint()) + app.Post("/models/galleries", auth, modelGalleryService.AddModelGalleryEndpoint()) + app.Delete("/models/galleries", auth, modelGalleryService.RemoveModelGalleryEndpoint()) + app.Get("/models/jobs/:uuid", auth, modelGalleryService.GetOpStatusEndpoint()) + app.Get("/models/jobs", auth, modelGalleryService.GetAllStatusEndpoint()) + + // openAI compatible API endpoint + + // chat + app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, options)) + app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, options)) + + // edit + app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options)) + app.Post("/edits", auth, openai.EditEndpoint(cl, options)) + + // completion + app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, options)) + app.Post("/completions", auth, openai.CompletionEndpoint(cl, options)) + app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, options)) + + // embeddings + app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, options)) + app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, options)) + app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, options)) + + // audio + app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, options)) + app.Post("/tts", auth, localai.TTSEndpoint(cl, options)) + + // images + app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, options)) + + if options.ImageDir != "" { + app.Static("/generated-images", options.ImageDir) + } + + if options.AudioDir != "" { + app.Static("/generated-audio", options.AudioDir) + } + + ok := func(c *fiber.Ctx) error { + return c.SendStatus(200) + } + + // Kubernetes health checks + app.Get("/healthz", ok) + app.Get("/readyz", ok) + + // Experimental Backend Statistics Module + backendMonitor := localai.NewBackendMonitor(cl, options) // Split out for now + app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor)) + app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor)) + + // models + app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cl)) + app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cl)) + + app.Get("/metrics", metrics.MetricsHandler()) + + return app, nil +} diff --git a/core/http/api_test.go b/api/api_test.go similarity index 93% rename from core/http/api_test.go rename to api/api_test.go index 54c76a640df5..a71b450ada7d 100644 --- a/core/http/api_test.go +++ b/api/api_test.go @@ -1,4 +1,4 @@ -package http_test +package api_test import ( "bytes" @@ -13,12 +13,11 @@ import ( "path/filepath" "runtime" - server "github.com/go-skynet/LocalAI/core/http" - "github.com/go-skynet/LocalAI/core/services" - "github.com/go-skynet/LocalAI/core/startup" + . "github.com/go-skynet/LocalAI/api" + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/metrics" "github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/schema" "github.com/go-skynet/LocalAI/pkg/utils" "github.com/gofiber/fiber/v2" . "github.com/onsi/ginkgo/v2" @@ -119,15 +118,16 @@ var backendAssets embed.FS var _ = Describe("API test", func() { var app *fiber.App + var modelLoader *model.ModelLoader var client *openai.Client var client2 *openaigo.Client var c context.Context var cancel context.CancelFunc var tmpdir string - commonOpts := []schema.AppOption{ - schema.WithDebug(true), - schema.WithDisableMessage(true), + commonOpts := []options.AppOption{ + options.WithDebug(true), + options.WithDisableMessage(true), } Context("API with ephemeral models", func() { @@ -136,6 +136,7 @@ var _ = Describe("API test", func() { tmpdir, err = os.MkdirTemp("", "") Expect(err).ToNot(HaveOccurred()) + modelLoader = model.NewModelLoader(tmpdir) c, cancel = context.WithCancel(context.Background()) g := []gallery.GalleryModel{ @@ -162,20 +163,15 @@ var _ = Describe("API test", func() { }, } - metricsService, err := services.SetupMetrics() + metricsService, err := metrics.SetupMetrics() Expect(err).ToNot(HaveOccurred()) - cl, ml, options, err := startup.Startup( + app, err = App( append(commonOpts, - schema.WithMetrics(metricsService), - schema.WithContext(c), - schema.WithGalleries(galleries), - schema.WithModelPath(tmpdir), - schema.WithBackendAssets(backendAssets), - schema.WithBackendAssetsOutput(tmpdir))...) - - Expect(err).ToNot(HaveOccurred()) - app, err = server.App(cl, ml, options) + options.WithMetrics(metricsService), + options.WithContext(c), + options.WithGalleries(galleries), + options.WithModelLoader(modelLoader), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(tmpdir))...) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -479,6 +475,7 @@ var _ = Describe("API test", func() { tmpdir, err = os.MkdirTemp("", "") Expect(err).ToNot(HaveOccurred()) + modelLoader = model.NewModelLoader(tmpdir) c, cancel = context.WithCancel(context.Background()) galleries := []gallery.Gallery{ @@ -488,22 +485,21 @@ var _ = Describe("API test", func() { }, } - metricsService, err := services.SetupMetrics() + metricsService, err := metrics.SetupMetrics() Expect(err).ToNot(HaveOccurred()) - cl, ml, options, err := startup.Startup( + app, err = App( append(commonOpts, - schema.WithContext(c), - schema.WithMetrics(metricsService), - schema.WithAudioDir(tmpdir), - schema.WithImageDir(tmpdir), - schema.WithGalleries(galleries), - schema.WithModelPath(tmpdir), - schema.WithBackendAssets(backendAssets), - schema.WithBackendAssetsOutput(tmpdir))..., + options.WithContext(c), + options.WithMetrics(metricsService), + options.WithAudioDir(tmpdir), + options.WithImageDir(tmpdir), + options.WithGalleries(galleries), + options.WithModelLoader(modelLoader), + options.WithBackendAssets(backendAssets), + options.WithBackendAssetsOutput(tmpdir))..., ) Expect(err).ToNot(HaveOccurred()) - app, err = server.App(cl, ml, options) go app.Listen("127.0.0.1:9090") defaultConfig := openai.DefaultConfig("") @@ -594,21 +590,20 @@ var _ = Describe("API test", func() { Context("API query", func() { BeforeEach(func() { + modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) c, cancel = context.WithCancel(context.Background()) - metricsService, err := services.SetupMetrics() + metricsService, err := metrics.SetupMetrics() Expect(err).ToNot(HaveOccurred()) - cl, ml, options, err := startup.Startup( + app, err = App( append(commonOpts, - schema.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")), - schema.WithContext(c), - schema.WithModelPath(os.Getenv("MODELS_PATH")), - schema.WithMetrics(metricsService), + options.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")), + options.WithContext(c), + options.WithModelLoader(modelLoader), + options.WithMetrics(metricsService), )...) Expect(err).ToNot(HaveOccurred()) - app, err = server.App(cl, ml, options) - Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") defaultConfig := openai.DefaultConfig("") @@ -807,21 +802,20 @@ var _ = Describe("API test", func() { Context("Config file", func() { BeforeEach(func() { + modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) c, cancel = context.WithCancel(context.Background()) - metricsService, err := services.SetupMetrics() + metricsService, err := metrics.SetupMetrics() Expect(err).ToNot(HaveOccurred()) - cl, ml, options, err := startup.Startup( + app, err = App( append(commonOpts, - schema.WithContext(c), - schema.WithMetrics(metricsService), - schema.WithModelPath(os.Getenv("MODELS_PATH")), - schema.WithConfigFile(os.Getenv("CONFIG_FILE")))..., + options.WithContext(c), + options.WithMetrics(metricsService), + options.WithModelLoader(modelLoader), + options.WithConfigFile(os.Getenv("CONFIG_FILE")))..., ) Expect(err).ToNot(HaveOccurred()) - app, err = server.App(cl, ml, options) - Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") defaultConfig := openai.DefaultConfig("") diff --git a/core/http/apt_suite_test.go b/api/apt_suite_test.go similarity index 90% rename from core/http/apt_suite_test.go rename to api/apt_suite_test.go index 0269a97321df..e3c15c048b14 100644 --- a/core/http/apt_suite_test.go +++ b/api/apt_suite_test.go @@ -1,4 +1,4 @@ -package http_test +package api_test import ( "testing" diff --git a/core/backend/embeddings.go b/api/backend/embeddings.go similarity index 50% rename from core/backend/embeddings.go rename to api/backend/embeddings.go index 2ef4ac6b5414..63f1a831e26d 100644 --- a/core/backend/embeddings.go +++ b/api/backend/embeddings.go @@ -2,17 +2,14 @@ package backend import ( "fmt" - "time" - "github.com/go-skynet/LocalAI/core/services" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/pkg/grpc" - "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/schema" - "github.com/google/uuid" - "github.com/rs/zerolog/log" + model "github.com/go-skynet/LocalAI/pkg/model" ) -func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c schema.Config, o *schema.StartupOptions) (func() ([]float32, error), error) { +func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) { if !c.Embeddings { return nil, fmt.Errorf("endpoint disabled for this model by API configuration") } @@ -30,7 +27,6 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c schema. model.WithAssetDir(o.AssetsDestination), model.WithModel(modelFile), model.WithContext(o.Context), - model.WithExternalBackends(o.ExternalGRPCBackends, false), }) if c.Backend == "" { @@ -94,51 +90,3 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c schema. return embeds, nil }, nil } - -func EmbeddingOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.OpenAIResponse, error) { - config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions) - if err != nil { - return nil, fmt.Errorf("failed reading parameters from request:%w", err) - } - - log.Debug().Msgf("Parameter Config: %+v", config) - items := []schema.Item{} - - for i, s := range config.InputToken { - // get the model function to call for the result - embedFn, err := ModelEmbedding("", s, ml, *config, startupOptions) - if err != nil { - return nil, err - } - - embeddings, err := embedFn() - if err != nil { - return nil, err - } - items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"}) - } - - for i, s := range config.InputStrings { - // get the model function to call for the result - embedFn, err := ModelEmbedding(s, []int{}, ml, *config, startupOptions) - if err != nil { - return nil, err - } - - embeddings, err := embedFn() - if err != nil { - return nil, err - } - items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"}) - } - - id := uuid.New().String() - created := int(time.Now().Unix()) - return &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Data: items, - Object: "list", - }, nil -} diff --git a/api/backend/image.go b/api/backend/image.go new file mode 100644 index 000000000000..6183269fd3ca --- /dev/null +++ b/api/backend/image.go @@ -0,0 +1,61 @@ +package backend + +import ( + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/grpc/proto" + model "github.com/go-skynet/LocalAI/pkg/model" +) + +func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) { + + opts := modelOpts(c, o, []model.Option{ + model.WithBackendString(c.Backend), + model.WithAssetDir(o.AssetsDestination), + model.WithThreads(uint32(c.Threads)), + model.WithContext(o.Context), + model.WithModel(c.Model), + model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{ + CUDA: c.CUDA || c.Diffusers.CUDA, + SchedulerType: c.Diffusers.SchedulerType, + PipelineType: c.Diffusers.PipelineType, + CFGScale: c.Diffusers.CFGScale, + LoraAdapter: c.LoraAdapter, + LoraScale: c.LoraScale, + LoraBase: c.LoraBase, + IMG2IMG: c.Diffusers.IMG2IMG, + CLIPModel: c.Diffusers.ClipModel, + CLIPSubfolder: c.Diffusers.ClipSubFolder, + CLIPSkip: int32(c.Diffusers.ClipSkip), + ControlNet: c.Diffusers.ControlNet, + }), + }) + + inferenceModel, err := loader.BackendLoader( + opts..., + ) + if err != nil { + return nil, err + } + + fn := func() error { + _, err := inferenceModel.GenerateImage( + o.Context, + &proto.GenerateImageRequest{ + Height: int32(height), + Width: int32(width), + Mode: int32(mode), + Step: int32(step), + Seed: int32(seed), + CLIPSkip: int32(c.Diffusers.ClipSkip), + PositivePrompt: positive_prompt, + NegativePrompt: negative_prompt, + Dst: dst, + Src: src, + EnableParameters: c.Diffusers.EnableParameters, + }) + return err + } + + return fn, nil +} diff --git a/api/backend/llm.go b/api/backend/llm.go new file mode 100644 index 000000000000..bd320b6155ab --- /dev/null +++ b/api/backend/llm.go @@ -0,0 +1,167 @@ +package backend + +import ( + "context" + "os" + "regexp" + "strings" + "sync" + "unicode/utf8" + + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/gallery" + "github.com/go-skynet/LocalAI/pkg/grpc" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/utils" +) + +type LLMResponse struct { + Response string // should this be []byte? + Usage TokenUsage +} + +type TokenUsage struct { + Prompt int + Completion int +} + +func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) { + modelFile := c.Model + + grpcOpts := gRPCModelOpts(c) + + var inferenceModel *grpc.Client + var err error + + opts := modelOpts(c, o, []model.Option{ + model.WithLoadGRPCLoadModelOpts(grpcOpts), + model.WithThreads(uint32(c.Threads)), // some models uses this to allocate threads during startup + model.WithAssetDir(o.AssetsDestination), + model.WithModel(modelFile), + model.WithContext(o.Context), + }) + + if c.Backend != "" { + opts = append(opts, model.WithBackendString(c.Backend)) + } + + // Check if the modelFile exists, if it doesn't try to load it from the gallery + if o.AutoloadGalleries { // experimental + if _, err := os.Stat(modelFile); os.IsNotExist(err) { + utils.ResetDownloadTimers() + // if we failed to load the model, we try to download it + err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction) + if err != nil { + return nil, err + } + } + } + + if c.Backend == "" { + inferenceModel, err = loader.GreedyLoader(opts...) + } else { + inferenceModel, err = loader.BackendLoader(opts...) + } + + if err != nil { + return nil, err + } + + // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported + fn := func() (LLMResponse, error) { + opts := gRPCPredictOpts(c, loader.ModelPath) + opts.Prompt = s + opts.Images = images + + tokenUsage := TokenUsage{} + + // check the per-model feature flag for usage, since tokenCallback may have a cost. + // Defaults to off as for now it is still experimental + if c.FeatureFlag.Enabled("usage") { + userTokenCallback := tokenCallback + if userTokenCallback == nil { + userTokenCallback = func(token string, usage TokenUsage) bool { + return true + } + } + + promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts) + if pErr == nil && promptInfo.Length > 0 { + tokenUsage.Prompt = int(promptInfo.Length) + } + + tokenCallback = func(token string, usage TokenUsage) bool { + tokenUsage.Completion++ + return userTokenCallback(token, tokenUsage) + } + } + + if tokenCallback != nil { + ss := "" + + var partialRune []byte + err := inferenceModel.PredictStream(ctx, opts, func(chars []byte) { + partialRune = append(partialRune, chars...) + + for len(partialRune) > 0 { + r, size := utf8.DecodeRune(partialRune) + if r == utf8.RuneError { + // incomplete rune, wait for more bytes + break + } + + tokenCallback(string(r), tokenUsage) + ss += string(r) + + partialRune = partialRune[size:] + } + }) + return LLMResponse{ + Response: ss, + Usage: tokenUsage, + }, err + } else { + // TODO: Is the chicken bit the only way to get here? is that acceptable? + reply, err := inferenceModel.Predict(ctx, opts) + if err != nil { + return LLMResponse{}, err + } + return LLMResponse{ + Response: string(reply.Message), + Usage: tokenUsage, + }, err + } + } + + return fn, nil +} + +var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) +var mu sync.Mutex = sync.Mutex{} + +func Finetune(config config.Config, input, prediction string) string { + if config.Echo { + prediction = input + prediction + } + + for _, c := range config.Cutstrings { + mu.Lock() + reg, ok := cutstrings[c] + if !ok { + cutstrings[c] = regexp.MustCompile(c) + reg = cutstrings[c] + } + mu.Unlock() + prediction = reg.ReplaceAllString(prediction, "") + } + + for _, c := range config.TrimSpace { + prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c)) + } + + for _, c := range config.TrimSuffix { + prediction = strings.TrimSpace(strings.TrimSuffix(prediction, c)) + } + return prediction +} diff --git a/core/backend/options.go b/api/backend/options.go similarity index 90% rename from core/backend/options.go rename to api/backend/options.go index 3e6132d63779..3266d602cce2 100644 --- a/core/backend/options.go +++ b/api/backend/options.go @@ -5,11 +5,13 @@ import ( "path/filepath" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/schema" + model "github.com/go-skynet/LocalAI/pkg/model" + + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" ) -func modelOpts(c schema.Config, o *schema.StartupOptions, opts []model.Option) []model.Option { +func modelOpts(c config.Config, o *options.Option, opts []model.Option) []model.Option { if o.SingleBackend { opts = append(opts, model.WithSingleActiveBackend()) } @@ -33,7 +35,7 @@ func modelOpts(c schema.Config, o *schema.StartupOptions, opts []model.Option) [ return opts } -func gRPCModelOpts(c schema.Config) *pb.ModelOptions { +func gRPCModelOpts(c config.Config) *pb.ModelOptions { b := 512 if c.Batch != 0 { b = c.Batch @@ -80,7 +82,7 @@ func gRPCModelOpts(c schema.Config) *pb.ModelOptions { } } -func gRPCPredictOpts(c schema.Config, modelPath string) *pb.PredictOptions { +func gRPCPredictOpts(c config.Config, modelPath string) *pb.PredictOptions { promptCachePath := "" if c.PromptCachePath != "" { p := filepath.Join(modelPath, c.PromptCachePath) diff --git a/api/backend/transcript.go b/api/backend/transcript.go new file mode 100644 index 000000000000..77427839992a --- /dev/null +++ b/api/backend/transcript.go @@ -0,0 +1,39 @@ +package backend + +import ( + "context" + "fmt" + + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/schema" + + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/grpc/proto" + model "github.com/go-skynet/LocalAI/pkg/model" +) + +func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*schema.Result, error) { + + opts := modelOpts(c, o, []model.Option{ + model.WithBackendString(model.WhisperBackend), + model.WithModel(c.Model), + model.WithContext(o.Context), + model.WithThreads(uint32(c.Threads)), + model.WithAssetDir(o.AssetsDestination), + }) + + whisperModel, err := o.Loader.BackendLoader(opts...) + if err != nil { + return nil, err + } + + if whisperModel == nil { + return nil, fmt.Errorf("could not load whisper model") + } + + return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ + Dst: audio, + Language: language, + Threads: uint32(c.Threads), + }) +} diff --git a/core/backend/tts.go b/api/backend/tts.go similarity index 76% rename from core/backend/tts.go rename to api/backend/tts.go index cd868e9a141a..ae8f53eea938 100644 --- a/core/backend/tts.go +++ b/api/backend/tts.go @@ -6,9 +6,10 @@ import ( "os" "path/filepath" + api_config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/schema" + model "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/utils" ) @@ -28,19 +29,18 @@ func generateUniqueFileName(dir, baseName, ext string) string { } } -func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *schema.StartupOptions) (string, *proto.Result, error) { +func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *options.Option) (string, *proto.Result, error) { bb := backend if bb == "" { bb = model.PiperBackend } - opts := modelOpts(schema.Config{}, o, []model.Option{ + opts := modelOpts(api_config.Config{}, o, []model.Option{ model.WithBackendString(bb), model.WithModel(modelFile), model.WithContext(o.Context), model.WithAssetDir(o.AssetsDestination), - model.WithExternalBackends(o.ExternalGRPCBackends, false), }) - piperModel, err := loader.BackendLoader(opts...) + piperModel, err := o.Loader.BackendLoader(opts...) if err != nil { return "", nil, err } @@ -60,8 +60,8 @@ func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *sch modelPath := "" if modelFile != "" { if bb != model.TransformersMusicGen { - modelPath = filepath.Join(o.ModelPath, modelFile) - if err := utils.VerifyPath(modelPath, o.ModelPath); err != nil { + modelPath = filepath.Join(o.Loader.ModelPath, modelFile) + if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil { return "", nil, err } } else { diff --git a/pkg/schema/config.go b/api/config/config.go similarity index 60% rename from pkg/schema/config.go rename to api/config/config.go index 0f271ae35192..ab62841b9f22 100644 --- a/pkg/schema/config.go +++ b/api/config/config.go @@ -1,11 +1,16 @@ -package schema +package api_config import ( - "encoding/json" + "errors" "fmt" + "io/fs" "os" + "path/filepath" + "strings" + "sync" "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/rs/zerolog/log" "gopkg.in/yaml.v3" ) @@ -147,6 +152,11 @@ type TemplateConfig struct { Functions string `yaml:"function"` } +type ConfigLoader struct { + configs map[string]Config + sync.Mutex +} + func (c *Config) SetFunctionCallString(s string) { c.functionCallString = s } @@ -183,6 +193,11 @@ func DefaultConfig(modelFile string) *Config { } } +func NewConfigLoader() *ConfigLoader { + return &ConfigLoader{ + configs: make(map[string]Config), + } +} func ReadConfigFile(file string) ([]*Config, error) { c := &[]*Config{} f, err := os.ReadFile(file) @@ -196,7 +211,7 @@ func ReadConfigFile(file string) ([]*Config, error) { return *c, nil } -func ReadSingleConfigFile(file string) (*Config, error) { +func ReadConfig(file string) (*Config, error) { c := &Config{} f, err := os.ReadFile(file) if err != nil { @@ -209,192 +224,136 @@ func ReadSingleConfigFile(file string) (*Config, error) { return c, nil } -func UpdateConfigFromOpenAIRequest(config *Config, input *OpenAIRequest) { - if input.Echo { - config.Echo = input.Echo - } - if input.TopK != 0 { - config.TopK = input.TopK - } - if input.TopP != 0 { - config.TopP = input.TopP - } - - if input.Backend != "" { - config.Backend = input.Backend - } - - if input.ClipSkip != 0 { - config.Diffusers.ClipSkip = input.ClipSkip - } - - if input.ModelBaseName != "" { - config.AutoGPTQ.ModelBaseName = input.ModelBaseName - } - - if input.NegativePromptScale != 0 { - config.NegativePromptScale = input.NegativePromptScale - } - - if input.UseFastTokenizer { - config.UseFastTokenizer = input.UseFastTokenizer - } - - if input.NegativePrompt != "" { - config.NegativePrompt = input.NegativePrompt - } - - if input.RopeFreqBase != 0 { - config.RopeFreqBase = input.RopeFreqBase - } - - if input.RopeFreqScale != 0 { - config.RopeFreqScale = input.RopeFreqScale - } - - if input.Grammar != "" { - config.Grammar = input.Grammar +func (cm *ConfigLoader) LoadConfigFile(file string) error { + cm.Lock() + defer cm.Unlock() + c, err := ReadConfigFile(file) + if err != nil { + return fmt.Errorf("cannot load config file: %w", err) } - if input.Temperature != 0 { - config.Temperature = input.Temperature + for _, cc := range c { + cm.configs[cc.Name] = *cc } + return nil +} - if input.Maxtokens != 0 { - config.Maxtokens = input.Maxtokens +func (cm *ConfigLoader) LoadConfig(file string) error { + cm.Lock() + defer cm.Unlock() + c, err := ReadConfig(file) + if err != nil { + return fmt.Errorf("cannot read config file: %w", err) } - if input.RepeatPenalty != 0 { - config.RepeatPenalty = input.RepeatPenalty - } + cm.configs[c.Name] = *c + return nil +} - if input.Keep != 0 { - config.Keep = input.Keep - } +func (cm *ConfigLoader) GetConfig(m string) (Config, bool) { + cm.Lock() + defer cm.Unlock() + v, exists := cm.configs[m] + return v, exists +} - if input.Batch != 0 { - config.Batch = input.Batch +func (cm *ConfigLoader) GetAllConfigs() []Config { + cm.Lock() + defer cm.Unlock() + var res []Config + for _, v := range cm.configs { + res = append(res, v) } + return res +} - if input.F16 { - config.F16 = input.F16 +func (cm *ConfigLoader) ListConfigs() []string { + cm.Lock() + defer cm.Unlock() + var res []string + for k := range cm.configs { + res = append(res, k) } + return res +} - if input.IgnoreEOS { - config.IgnoreEOS = input.IgnoreEOS - } +// Preload prepare models if they are not local but url or huggingface repositories +func (cm *ConfigLoader) Preload(modelPath string) error { + cm.Lock() + defer cm.Unlock() - if input.Seed != 0 { - config.Seed = input.Seed + status := func(fileName, current, total string, percent float64) { + utils.DisplayDownloadFunction(fileName, current, total, percent) } - if input.Mirostat != 0 { - config.LLMConfig.Mirostat = input.Mirostat - } + log.Info().Msgf("Preloading models from %s", modelPath) - if input.MirostatETA != 0 { - config.LLMConfig.MirostatETA = input.MirostatETA - } + for i, config := range cm.configs { - if input.MirostatTAU != 0 { - config.LLMConfig.MirostatTAU = input.MirostatTAU - } + // Download files and verify their SHA + for _, file := range config.DownloadFiles { + log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename) - if input.TypicalP != 0 { - config.TypicalP = input.TypicalP - } + if err := utils.VerifyPath(file.Filename, modelPath); err != nil { + return err + } + // Create file path + filePath := filepath.Join(modelPath, file.Filename) - switch stop := input.Stop.(type) { - case string: - if stop != "" { - config.StopWords = append(config.StopWords, stop) - } - case []interface{}: - for _, pp := range stop { - if s, ok := pp.(string); ok { - config.StopWords = append(config.StopWords, s) + if err := utils.DownloadFile(file.URI, filePath, file.SHA256, status); err != nil { + return err } } - } - // Decode each request's message content - index := 0 - for i, m := range input.Messages { - switch content := m.Content.(type) { - case string: - input.Messages[i].StringContent = content - case []interface{}: - dat, _ := json.Marshal(content) - c := []Content{} - json.Unmarshal(dat, &c) - for _, pp := range c { - if pp.Type == "text" { - input.Messages[i].StringContent = pp.Text - } else if pp.Type == "image_url" { - // Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64: - base64, err := utils.GetBase64Image(pp.ImageURL.URL) - if err == nil { - input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff - // set a placeholder for each image - input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent - index++ - } else { - fmt.Print("Failed encoding image", err) - } + modelURL := config.PredictionOptions.Model + modelURL = utils.ConvertURL(modelURL) - } - } - } - } + if utils.LooksLikeURL(modelURL) { + // md5 of model name + md5Name := utils.MD5(modelURL) - // TODO: check that this was merged correctly? I _think_ it is? - switch inputs := input.Input.(type) { - case string: - if inputs != "" { - config.InputStrings = append(config.InputStrings, inputs) - } - case []interface{}: - for _, pp := range inputs { - switch i := pp.(type) { - case string: - config.InputStrings = append(config.InputStrings, i) - case []interface{}: - tokens := []int{} - for _, ii := range i { - tokens = append(tokens, int(ii.(float64))) + // check if file exists + if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) { + err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", status) + if err != nil { + return err } - config.InputToken = append(config.InputToken, tokens) } + + cc := cm.configs[i] + c := &cc + c.PredictionOptions.Model = md5Name + cm.configs[i] = *c } } + return nil +} - // Can be either a string or an object - switch fnc := input.FunctionCall.(type) { - case string: - if fnc != "" { - config.SetFunctionCallString(fnc) - } - case map[string]interface{}: - var name string - n, exists := fnc["name"] - if exists { - nn, e := n.(string) - if e { - name = nn - } +func (cm *ConfigLoader) LoadConfigs(path string) error { + cm.Lock() + defer cm.Unlock() + entries, err := os.ReadDir(path) + if err != nil { + return err + } + files := make([]fs.FileInfo, 0, len(entries)) + for _, entry := range entries { + info, err := entry.Info() + if err != nil { + return err } - config.SetFunctionCallNameString(name) + files = append(files, info) } - - switch p := input.Prompt.(type) { - case string: - config.PromptStrings = append(config.PromptStrings, p) - case []interface{}: - for _, pp := range p { - if s, ok := pp.(string); ok { - config.PromptStrings = append(config.PromptStrings, s) - } + for _, file := range files { + // Skip templates, YAML and .keep files + if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") { + continue + } + c, err := ReadConfig(filepath.Join(path, file.Name())) + if err == nil { + cm.configs[c.Name] = *c } } + return nil } diff --git a/pkg/schema/config_test.go b/api/config/config_test.go similarity index 74% rename from pkg/schema/config_test.go rename to api/config/config_test.go index f5b7192e04c1..4b00d587eff2 100644 --- a/pkg/schema/config_test.go +++ b/api/config/config_test.go @@ -1,10 +1,11 @@ -package schema_test +package api_config_test import ( "os" - "github.com/go-skynet/LocalAI/core/services" - "github.com/go-skynet/LocalAI/pkg/schema" + . "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/model" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) @@ -18,7 +19,7 @@ var _ = Describe("Test cases for config related functions", func() { Context("Test Read configuration functions", func() { configFile = os.Getenv("CONFIG_FILE") It("Test ReadConfigFile", func() { - config, err := schema.ReadConfigFile(configFile) + config, err := ReadConfigFile(configFile) Expect(err).To(BeNil()) Expect(config).ToNot(BeNil()) // two configs in config.yaml @@ -27,8 +28,12 @@ var _ = Describe("Test cases for config related functions", func() { }) It("Test LoadConfigs", func() { - cm := services.NewConfigLoader() - err := cm.LoadConfigs(os.Getenv("MODELS_PATH")) + cm := NewConfigLoader() + opts := options.NewOptions() + modelLoader := model.NewModelLoader(os.Getenv("MODELS_PATH")) + options.WithModelLoader(modelLoader)(opts) + + err := cm.LoadConfigs(opts.Loader.ModelPath) Expect(err).To(BeNil()) Expect(cm.ListConfigs()).ToNot(BeNil()) diff --git a/pkg/schema/prediction.go b/api/config/prediction.go similarity index 99% rename from pkg/schema/prediction.go rename to api/config/prediction.go index efd085a4ad9b..d2fbb1fa9687 100644 --- a/pkg/schema/prediction.go +++ b/api/config/prediction.go @@ -1,4 +1,4 @@ -package schema +package api_config type PredictionOptions struct { diff --git a/api/localai/backend_monitor.go b/api/localai/backend_monitor.go new file mode 100644 index 000000000000..8cb0bb45ed14 --- /dev/null +++ b/api/localai/backend_monitor.go @@ -0,0 +1,162 @@ +package localai + +import ( + "context" + "fmt" + "strings" + + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/pkg/grpc/proto" + + "github.com/go-skynet/LocalAI/api/options" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" + + gopsutil "github.com/shirou/gopsutil/v3/process" +) + +type BackendMonitorRequest struct { + Model string `json:"model" yaml:"model"` +} + +type BackendMonitorResponse struct { + MemoryInfo *gopsutil.MemoryInfoStat + MemoryPercent float32 + CPUPercent float64 +} + +type BackendMonitor struct { + configLoader *config.ConfigLoader + options *options.Option // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name. +} + +func NewBackendMonitor(configLoader *config.ConfigLoader, options *options.Option) BackendMonitor { + return BackendMonitor{ + configLoader: configLoader, + options: options, + } +} + +func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*BackendMonitorResponse, error) { + config, exists := bm.configLoader.GetConfig(model) + var backend string + if exists { + backend = config.Model + } else { + // Last ditch effort: use it raw, see if a backend happens to match. + backend = model + } + + if !strings.HasSuffix(backend, ".bin") { + backend = fmt.Sprintf("%s.bin", backend) + } + + pid, err := bm.options.Loader.GetGRPCPID(backend) + + if err != nil { + log.Error().Msgf("model %s : failed to find pid %+v", model, err) + return nil, err + } + + // Name is slightly frightening but this does _not_ create a new process, rather it looks up an existing process by PID. + backendProcess, err := gopsutil.NewProcess(int32(pid)) + + if err != nil { + log.Error().Msgf("model %s [PID %d] : error getting process info %+v", model, pid, err) + return nil, err + } + + memInfo, err := backendProcess.MemoryInfo() + + if err != nil { + log.Error().Msgf("model %s [PID %d] : error getting memory info %+v", model, pid, err) + return nil, err + } + + memPercent, err := backendProcess.MemoryPercent() + if err != nil { + log.Error().Msgf("model %s [PID %d] : error getting memory percent %+v", model, pid, err) + return nil, err + } + + cpuPercent, err := backendProcess.CPUPercent() + if err != nil { + log.Error().Msgf("model %s [PID %d] : error getting cpu percent %+v", model, pid, err) + return nil, err + } + + return &BackendMonitorResponse{ + MemoryInfo: memInfo, + MemoryPercent: memPercent, + CPUPercent: cpuPercent, + }, nil +} + +func (bm BackendMonitor) getModelLoaderIDFromCtx(c *fiber.Ctx) (string, error) { + input := new(BackendMonitorRequest) + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return "", err + } + + config, exists := bm.configLoader.GetConfig(input.Model) + var backendId string + if exists { + backendId = config.Model + } else { + // Last ditch effort: use it raw, see if a backend happens to match. + backendId = input.Model + } + + if !strings.HasSuffix(backendId, ".bin") { + backendId = fmt.Sprintf("%s.bin", backendId) + } + + return backendId, nil +} + +func BackendMonitorEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + + backendId, err := bm.getModelLoaderIDFromCtx(c) + if err != nil { + return err + } + + model := bm.options.Loader.CheckIsLoaded(backendId) + if model == "" { + return fmt.Errorf("backend %s is not currently loaded", backendId) + } + + status, rpcErr := model.GRPC(false, nil).Status(context.TODO()) + if rpcErr != nil { + log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error()) + val, slbErr := bm.SampleLocalBackendProcess(backendId) + if slbErr != nil { + return fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error()) + } + return c.JSON(proto.StatusResponse{ + State: proto.StatusResponse_ERROR, + Memory: &proto.MemoryUsageData{ + Total: val.MemoryInfo.VMS, + Breakdown: map[string]uint64{ + "gopsutil-RSS": val.MemoryInfo.RSS, + }, + }, + }) + } + + return c.JSON(status) + } +} + +func BackendShutdownEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + backendId, err := bm.getModelLoaderIDFromCtx(c) + if err != nil { + return err + } + + return bm.options.Loader.ShutdownModel(backendId) + } +} diff --git a/api/localai/gallery.go b/api/localai/gallery.go new file mode 100644 index 000000000000..a2ad5bd1ac46 --- /dev/null +++ b/api/localai/gallery.go @@ -0,0 +1,326 @@ +package localai + +import ( + "context" + "fmt" + "os" + "slices" + "strings" + "sync" + + json "github.com/json-iterator/go" + "gopkg.in/yaml.v3" + + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/pkg/gallery" + "github.com/go-skynet/LocalAI/pkg/utils" + + "github.com/gofiber/fiber/v2" + "github.com/google/uuid" + "github.com/rs/zerolog/log" +) + +type galleryOp struct { + req gallery.GalleryModel + id string + galleries []gallery.Gallery + galleryName string +} + +type galleryOpStatus struct { + FileName string `json:"file_name"` + Error error `json:"error"` + Processed bool `json:"processed"` + Message string `json:"message"` + Progress float64 `json:"progress"` + TotalFileSize string `json:"file_size"` + DownloadedFileSize string `json:"downloaded_size"` +} + +type galleryApplier struct { + modelPath string + sync.Mutex + C chan galleryOp + statuses map[string]*galleryOpStatus +} + +func NewGalleryService(modelPath string) *galleryApplier { + return &galleryApplier{ + modelPath: modelPath, + C: make(chan galleryOp), + statuses: make(map[string]*galleryOpStatus), + } +} + +func prepareModel(modelPath string, req gallery.GalleryModel, cm *config.ConfigLoader, downloadStatus func(string, string, string, float64)) error { + + config, err := gallery.GetGalleryConfigFromURL(req.URL) + if err != nil { + return err + } + + config.Files = append(config.Files, req.AdditionalFiles...) + + return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus) +} + +func (g *galleryApplier) updateStatus(s string, op *galleryOpStatus) { + g.Lock() + defer g.Unlock() + g.statuses[s] = op +} + +func (g *galleryApplier) getStatus(s string) *galleryOpStatus { + g.Lock() + defer g.Unlock() + + return g.statuses[s] +} + +func (g *galleryApplier) getAllStatus() map[string]*galleryOpStatus { + g.Lock() + defer g.Unlock() + + return g.statuses +} + +func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) { + go func() { + for { + select { + case <-c.Done(): + return + case op := <-g.C: + utils.ResetDownloadTimers() + + g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0}) + + // updates the status with an error + updateError := func(e error) { + g.updateStatus(op.id, &galleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()}) + } + + // displayDownload displays the download progress + progressCallback := func(fileName string, current string, total string, percentage float64) { + g.updateStatus(op.id, &galleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) + utils.DisplayDownloadFunction(fileName, current, total, percentage) + } + + var err error + // if the request contains a gallery name, we apply the gallery from the gallery list + if op.galleryName != "" { + if strings.Contains(op.galleryName, "@") { + err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback) + } else { + err = gallery.InstallModelFromGalleryByName(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback) + } + } else { + err = prepareModel(g.modelPath, op.req, cm, progressCallback) + } + + if err != nil { + updateError(err) + continue + } + + // Reload models + err = cm.LoadConfigs(g.modelPath) + if err != nil { + updateError(err) + continue + } + + err = cm.Preload(g.modelPath) + if err != nil { + updateError(err) + continue + } + + g.updateStatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100}) + } + } + }() +} + +type galleryModel struct { + gallery.GalleryModel `yaml:",inline"` // https://github.com/go-yaml/yaml/issues/63 + ID string `json:"id"` +} + +func processRequests(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error { + var err error + for _, r := range requests { + utils.ResetDownloadTimers() + if r.ID == "" { + err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction) + } else { + if strings.Contains(r.ID, "@") { + err = gallery.InstallModelFromGallery( + galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction) + } else { + err = gallery.InstallModelFromGalleryByName( + galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction) + } + } + } + return err +} + +func ApplyGalleryFromFile(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error { + dat, err := os.ReadFile(s) + if err != nil { + return err + } + var requests []galleryModel + + if err := yaml.Unmarshal(dat, &requests); err != nil { + return err + } + + return processRequests(modelPath, s, cm, galleries, requests) +} + +func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error { + var requests []galleryModel + err := json.Unmarshal([]byte(s), &requests) + if err != nil { + return err + } + + return processRequests(modelPath, s, cm, galleries, requests) +} + +/// Endpoint Service + +type ModelGalleryService struct { + galleries []gallery.Gallery + modelPath string + galleryApplier *galleryApplier +} + +type GalleryModel struct { + ID string `json:"id"` + gallery.GalleryModel +} + +func CreateModelGalleryService(galleries []gallery.Gallery, modelPath string, galleryApplier *galleryApplier) ModelGalleryService { + return ModelGalleryService{ + galleries: galleries, + modelPath: modelPath, + galleryApplier: galleryApplier, + } +} + +func (mgs *ModelGalleryService) GetOpStatusEndpoint() func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + status := mgs.galleryApplier.getStatus(c.Params("uuid")) + if status == nil { + return fmt.Errorf("could not find any status for ID") + } + return c.JSON(status) + } +} + +func (mgs *ModelGalleryService) GetAllStatusEndpoint() func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + return c.JSON(mgs.galleryApplier.getAllStatus()) + } +} + +func (mgs *ModelGalleryService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + input := new(GalleryModel) + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return err + } + + uuid, err := uuid.NewUUID() + if err != nil { + return err + } + mgs.galleryApplier.C <- galleryOp{ + req: input.GalleryModel, + id: uuid.String(), + galleryName: input.ID, + galleries: mgs.galleries, + } + return c.JSON(struct { + ID string `json:"uuid"` + StatusURL string `json:"status"` + }{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()}) + } +} + +func (mgs *ModelGalleryService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries) + + models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath) + if err != nil { + return err + } + log.Debug().Msgf("Models found from galleries: %+v", models) + for _, m := range models { + log.Debug().Msgf("Model found from galleries: %+v", m) + } + dat, err := json.Marshal(models) + if err != nil { + return err + } + return c.Send(dat) + } +} + +// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents! +func (mgs *ModelGalleryService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + log.Debug().Msgf("Listing model galleries %+v", mgs.galleries) + dat, err := json.Marshal(mgs.galleries) + if err != nil { + return err + } + return c.Send(dat) + } +} + +func (mgs *ModelGalleryService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + input := new(gallery.Gallery) + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return err + } + if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool { + return gallery.Name == input.Name + }) { + return fmt.Errorf("%s already exists", input.Name) + } + dat, err := json.Marshal(mgs.galleries) + if err != nil { + return err + } + log.Debug().Msgf("Adding %+v to gallery list", *input) + mgs.galleries = append(mgs.galleries, *input) + return c.Send(dat) + } +} + +func (mgs *ModelGalleryService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + input := new(gallery.Gallery) + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return err + } + if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool { + return gallery.Name == input.Name + }) { + return fmt.Errorf("%s is not currently registered", input.Name) + } + mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool { + return gallery.Name == input.Name + }) + return c.Send(nil) + } +} diff --git a/api/localai/localai.go b/api/localai/localai.go new file mode 100644 index 000000000000..c9aee2ae5c34 --- /dev/null +++ b/api/localai/localai.go @@ -0,0 +1,32 @@ +package localai + +import ( + "github.com/go-skynet/LocalAI/api/backend" + config "github.com/go-skynet/LocalAI/api/config" + + "github.com/go-skynet/LocalAI/api/options" + "github.com/gofiber/fiber/v2" +) + +type TTSRequest struct { + Model string `json:"model" yaml:"model"` + Input string `json:"input" yaml:"input"` + Backend string `json:"backend" yaml:"backend"` +} + +func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + + input := new(TTSRequest) + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return err + } + + filePath, _, err := backend.ModelTTS(input.Backend, input.Input, input.Model, o.Loader, o) + if err != nil { + return err + } + return c.Download(filePath) + } +} diff --git a/api/openai/chat.go b/api/openai/chat.go new file mode 100644 index 000000000000..02bf6149499e --- /dev/null +++ b/api/openai/chat.go @@ -0,0 +1,399 @@ +package openai + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/go-skynet/LocalAI/api/backend" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/api/schema" + "github.com/go-skynet/LocalAI/pkg/grammar" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/gofiber/fiber/v2" + "github.com/google/uuid" + "github.com/rs/zerolog/log" + "github.com/valyala/fasthttp" +) + +func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + emptyMessage := "" + id := uuid.New().String() + created := int(time.Now().Unix()) + + process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { + initialMessage := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}}, + Object: "chat.completion.chunk", + } + responses <- initialMessage + + ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { + resp := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}}, + Object: "chat.completion.chunk", + Usage: schema.OpenAIUsage{ + PromptTokens: usage.Prompt, + CompletionTokens: usage.Completion, + TotalTokens: usage.Prompt + usage.Completion, + }, + } + + responses <- resp + return true + }) + close(responses) + } + return func(c *fiber.Ctx) error { + processFunctions := false + funcs := grammar.Functions{} + modelFile, input, err := readInput(c, o, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + log.Debug().Msgf("Configuration read: %+v", config) + + // Allow the user to set custom actions via config file + // to be "embedded" in each model + noActionName := "answer" + noActionDescription := "use this action to answer without performing any action" + + if config.FunctionsConfig.NoActionFunctionName != "" { + noActionName = config.FunctionsConfig.NoActionFunctionName + } + if config.FunctionsConfig.NoActionDescriptionName != "" { + noActionDescription = config.FunctionsConfig.NoActionDescriptionName + } + + if input.ResponseFormat.Type == "json_object" { + input.Grammar = grammar.JSONBNF + } + + // process functions if we have any defined or if we have a function call string + if len(input.Functions) > 0 && config.ShouldUseFunctions() { + log.Debug().Msgf("Response needs to process functions") + + processFunctions = true + + noActionGrammar := grammar.Function{ + Name: noActionName, + Description: noActionDescription, + Parameters: map[string]interface{}{ + "properties": map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "description": "The message to reply the user with", + }}, + }, + } + + // Append the no action function + funcs = append(funcs, input.Functions...) + if !config.FunctionsConfig.DisableNoAction { + funcs = append(funcs, noActionGrammar) + } + + // Force picking one of the functions by the request + if config.FunctionToCall() != "" { + funcs = funcs.Select(config.FunctionToCall()) + } + + // Update input grammar + jsStruct := funcs.ToJSONStructure() + config.Grammar = jsStruct.Grammar("") + } else if input.JSONFunctionGrammarObject != nil { + config.Grammar = input.JSONFunctionGrammarObject.Grammar("") + } + + // functions are not supported in stream mode (yet?) + toStream := input.Stream && !processFunctions + + log.Debug().Msgf("Parameters: %+v", config) + + var predInput string + + suppressConfigSystemPrompt := false + mess := []string{} + for messageIndex, i := range input.Messages { + var content string + role := i.Role + + // if function call, we might want to customize the role so we can display better that the "assistant called a json action" + // if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request + if i.FunctionCall != nil && i.Role == "assistant" { + roleFn := "assistant_function_call" + r := config.Roles[roleFn] + if r != "" { + role = roleFn + } + } + r := config.Roles[role] + contentExists := i.Content != nil && i.StringContent != "" + // First attempt to populate content via a chat message specific template + if config.TemplateConfig.ChatMessage != "" { + chatMessageData := model.ChatMessageTemplateData{ + SystemPrompt: config.SystemPrompt, + Role: r, + RoleName: role, + Content: i.StringContent, + MessageIndex: messageIndex, + } + templatedChatMessage, err := o.Loader.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData) + if err != nil { + log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err) + } else { + if templatedChatMessage == "" { + log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData) + continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf + } + log.Debug().Msgf("templated message for chat: %s", templatedChatMessage) + content = templatedChatMessage + } + } + // If this model doesn't have such a template, or if that template fails to return a value, template at the message level. + if content == "" { + if r != "" { + if contentExists { + content = fmt.Sprint(r, i.StringContent) + } + if i.FunctionCall != nil { + j, err := json.Marshal(i.FunctionCall) + if err == nil { + if contentExists { + content += "\n" + fmt.Sprint(r, " ", string(j)) + } else { + content = fmt.Sprint(r, " ", string(j)) + } + } + } + } else { + if contentExists { + content = fmt.Sprint(i.StringContent) + } + if i.FunctionCall != nil { + j, err := json.Marshal(i.FunctionCall) + if err == nil { + if contentExists { + content += "\n" + string(j) + } else { + content = string(j) + } + } + } + } + // Special Handling: System. We care if it was printed at all, not the r branch, so check seperately + if contentExists && role == "system" { + suppressConfigSystemPrompt = true + } + } + + mess = append(mess, content) + } + + predInput = strings.Join(mess, "\n") + log.Debug().Msgf("Prompt (before templating): %s", predInput) + + if toStream { + log.Debug().Msgf("Stream request received") + c.Context().SetContentType("text/event-stream") + //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) + // c.Set("Content-Type", "text/event-stream") + c.Set("Cache-Control", "no-cache") + c.Set("Connection", "keep-alive") + c.Set("Transfer-Encoding", "chunked") + } + + templateFile := "" + + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { + templateFile = config.Model + } + + if config.TemplateConfig.Chat != "" && !processFunctions { + templateFile = config.TemplateConfig.Chat + } + + if config.TemplateConfig.Functions != "" && processFunctions { + templateFile = config.TemplateConfig.Functions + } + + if templateFile != "" { + templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{ + SystemPrompt: config.SystemPrompt, + SuppressSystemPrompt: suppressConfigSystemPrompt, + Input: predInput, + Functions: funcs, + }) + if err == nil { + predInput = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", predInput) + } else { + log.Debug().Msgf("Template failed loading: %s", err.Error()) + } + } + + log.Debug().Msgf("Prompt (after templating): %s", predInput) + if processFunctions { + log.Debug().Msgf("Grammar: %+v", config.Grammar) + } + + if toStream { + responses := make(chan schema.OpenAIResponse) + + go process(predInput, input, config, o.Loader, responses) + + c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { + + usage := &schema.OpenAIUsage{} + + for ev := range responses { + usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.Encode(ev) + log.Debug().Msgf("Sending chunk: %s", buf.String()) + _, err := fmt.Fprintf(w, "data: %v\n", buf.String()) + if err != nil { + log.Debug().Msgf("Sending chunk failed: %v", err) + input.Cancel() + break + } + w.Flush() + } + + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{ + { + FinishReason: "stop", + Index: 0, + Delta: &schema.Message{Content: &emptyMessage}, + }}, + Object: "chat.completion.chunk", + Usage: *usage, + } + respData, _ := json.Marshal(resp) + + w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) + w.WriteString("data: [DONE]\n\n") + w.Flush() + })) + return nil + } + + result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]schema.Choice) { + if processFunctions { + // As we have to change the result before processing, we can't stream the answer (yet?) + ss := map[string]interface{}{} + // This prevent newlines to break JSON parsing for clients + s = utils.EscapeNewLines(s) + json.Unmarshal([]byte(s), &ss) + log.Debug().Msgf("Function return: %s %+v", s, ss) + + // The grammar defines the function name as "function", while OpenAI returns "name" + func_name := ss["function"] + // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object + args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) + d, _ := json.Marshal(args) + + ss["arguments"] = string(d) + ss["name"] = func_name + + // if do nothing, reply with a message + if func_name == noActionName { + log.Debug().Msgf("nothing to do, computing a reply") + + // If there is a message that the LLM already sends as part of the JSON reply, use it + arguments := map[string]interface{}{} + json.Unmarshal([]byte(d), &arguments) + m, exists := arguments["message"] + if exists { + switch message := m.(type) { + case string: + if message != "" { + log.Debug().Msgf("Reply received from LLM: %s", message) + message = backend.Finetune(*config, predInput, message) + log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) + + *c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &message}}) + return + } + } + } + + log.Debug().Msgf("No action received from LLM, without a message, computing a reply") + // Otherwise ask the LLM to understand the JSON output and the context, and return a message + // Note: This costs (in term of CPU) another computation + config.Grammar = "" + images := []string{} + for _, m := range input.Messages { + images = append(images, m.StringImages...) + } + predFunc, err := backend.ModelInference(input.Context, predInput, images, o.Loader, *config, o, nil) + if err != nil { + log.Error().Msgf("inference error: %s", err.Error()) + return + } + + prediction, err := predFunc() + if err != nil { + log.Error().Msgf("inference error: %s", err.Error()) + return + } + + fineTunedResponse := backend.Finetune(*config, predInput, prediction.Response) + *c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &fineTunedResponse}}) + } else { + // otherwise reply with the function call + *c = append(*c, schema.Choice{ + FinishReason: "function_call", + Message: &schema.Message{Role: "assistant", FunctionCall: ss}, + }) + } + + return + } + *c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}}) + }, nil) + if err != nil { + return err + } + + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: result, + Object: "chat.completion", + Usage: schema.OpenAIUsage{ + PromptTokens: tokenUsage.Prompt, + CompletionTokens: tokenUsage.Completion, + TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, + }, + } + respData, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", respData) + + // Return the prediction in the response body + return c.JSON(resp) + } +} diff --git a/api/openai/completion.go b/api/openai/completion.go new file mode 100644 index 000000000000..c0607632b93b --- /dev/null +++ b/api/openai/completion.go @@ -0,0 +1,199 @@ +package openai + +import ( + "bufio" + "bytes" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/go-skynet/LocalAI/api/backend" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/api/schema" + "github.com/go-skynet/LocalAI/pkg/grammar" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/google/uuid" + "github.com/rs/zerolog/log" + "github.com/valyala/fasthttp" +) + +// https://platform.openai.com/docs/api-reference/completions +func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + id := uuid.New().String() + created := int(time.Now().Unix()) + + process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { + ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { + resp := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{ + { + Index: 0, + Text: s, + }, + }, + Object: "text_completion", + Usage: schema.OpenAIUsage{ + PromptTokens: usage.Prompt, + CompletionTokens: usage.Completion, + TotalTokens: usage.Prompt + usage.Completion, + }, + } + log.Debug().Msgf("Sending goroutine: %s", s) + + responses <- resp + return true + }) + close(responses) + } + + return func(c *fiber.Ctx) error { + modelFile, input, err := readInput(c, o, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + log.Debug().Msgf("`input`: %+v", input) + + config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + if input.ResponseFormat.Type == "json_object" { + input.Grammar = grammar.JSONBNF + } + + log.Debug().Msgf("Parameter Config: %+v", config) + + if input.Stream { + log.Debug().Msgf("Stream request received") + c.Context().SetContentType("text/event-stream") + //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) + //c.Set("Content-Type", "text/event-stream") + c.Set("Cache-Control", "no-cache") + c.Set("Connection", "keep-alive") + c.Set("Transfer-Encoding", "chunked") + } + + templateFile := "" + + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { + templateFile = config.Model + } + + if config.TemplateConfig.Completion != "" { + templateFile = config.TemplateConfig.Completion + } + + if input.Stream { + if len(config.PromptStrings) > 1 { + return errors.New("cannot handle more than 1 `PromptStrings` when Streaming") + } + + predInput := config.PromptStrings[0] + + if templateFile != "" { + templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{ + Input: predInput, + }) + if err == nil { + predInput = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", predInput) + } + } + + responses := make(chan schema.OpenAIResponse) + + go process(predInput, input, config, o.Loader, responses) + + c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { + + for ev := range responses { + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.Encode(ev) + + log.Debug().Msgf("Sending chunk: %s", buf.String()) + fmt.Fprintf(w, "data: %v\n", buf.String()) + w.Flush() + } + + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{ + { + Index: 0, + FinishReason: "stop", + }, + }, + Object: "text_completion", + } + respData, _ := json.Marshal(resp) + + w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) + w.WriteString("data: [DONE]\n\n") + w.Flush() + })) + return nil + } + + var result []schema.Choice + + totalTokenUsage := backend.TokenUsage{} + + for k, i := range config.PromptStrings { + if templateFile != "" { + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{ + SystemPrompt: config.SystemPrompt, + Input: i, + }) + if err == nil { + i = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", i) + } + } + + r, tokenUsage, err := ComputeChoices( + input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) { + *c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k}) + }, nil) + if err != nil { + return err + } + + totalTokenUsage.Prompt += tokenUsage.Prompt + totalTokenUsage.Completion += tokenUsage.Completion + + result = append(result, r...) + } + + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: result, + Object: "text_completion", + Usage: schema.OpenAIUsage{ + PromptTokens: totalTokenUsage.Prompt, + CompletionTokens: totalTokenUsage.Completion, + TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion, + }, + } + + jsonResult, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", jsonResult) + + // Return the prediction in the response body + return c.JSON(resp) + } +} diff --git a/api/openai/edit.go b/api/openai/edit.go new file mode 100644 index 000000000000..888b9db7ffd4 --- /dev/null +++ b/api/openai/edit.go @@ -0,0 +1,94 @@ +package openai + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/go-skynet/LocalAI/api/backend" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/api/schema" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/google/uuid" + + "github.com/rs/zerolog/log" +) + +func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + modelFile, input, err := readInput(c, o, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + log.Debug().Msgf("Parameter Config: %+v", config) + + templateFile := "" + + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { + templateFile = config.Model + } + + if config.TemplateConfig.Edit != "" { + templateFile = config.TemplateConfig.Edit + } + + var result []schema.Choice + totalTokenUsage := backend.TokenUsage{} + + for _, i := range config.InputStrings { + if templateFile != "" { + templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{ + Input: i, + Instruction: input.Instruction, + SystemPrompt: config.SystemPrompt, + }) + if err == nil { + i = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", i) + } + } + + r, tokenUsage, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) { + *c = append(*c, schema.Choice{Text: s}) + }, nil) + if err != nil { + return err + } + + totalTokenUsage.Prompt += tokenUsage.Prompt + totalTokenUsage.Completion += tokenUsage.Completion + + result = append(result, r...) + } + + id := uuid.New().String() + created := int(time.Now().Unix()) + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: result, + Object: "edit", + Usage: schema.OpenAIUsage{ + PromptTokens: totalTokenUsage.Prompt, + CompletionTokens: totalTokenUsage.Completion, + TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion, + }, + } + + jsonResult, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", jsonResult) + + // Return the prediction in the response body + return c.JSON(resp) + } +} diff --git a/api/openai/embeddings.go b/api/openai/embeddings.go new file mode 100644 index 000000000000..15e31e92c6eb --- /dev/null +++ b/api/openai/embeddings.go @@ -0,0 +1,78 @@ +package openai + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/go-skynet/LocalAI/api/backend" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/schema" + "github.com/google/uuid" + + "github.com/go-skynet/LocalAI/api/options" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +// https://platform.openai.com/docs/api-reference/embeddings +func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + model, input, err := readInput(c, o, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + log.Debug().Msgf("Parameter Config: %+v", config) + items := []schema.Item{} + + for i, s := range config.InputToken { + // get the model function to call for the result + embedFn, err := backend.ModelEmbedding("", s, o.Loader, *config, o) + if err != nil { + return err + } + + embeddings, err := embedFn() + if err != nil { + return err + } + items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"}) + } + + for i, s := range config.InputStrings { + // get the model function to call for the result + embedFn, err := backend.ModelEmbedding(s, []int{}, o.Loader, *config, o) + if err != nil { + return err + } + + embeddings, err := embedFn() + if err != nil { + return err + } + items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"}) + } + + id := uuid.New().String() + created := int(time.Now().Unix()) + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Data: items, + Object: "list", + } + + jsonResult, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", jsonResult) + + // Return the prediction in the response body + return c.JSON(resp) + } +} diff --git a/api/openai/image.go b/api/openai/image.go new file mode 100644 index 000000000000..3e4bc349af3a --- /dev/null +++ b/api/openai/image.go @@ -0,0 +1,239 @@ +package openai + +import ( + "bufio" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/go-skynet/LocalAI/api/schema" + "github.com/google/uuid" + + "github.com/go-skynet/LocalAI/api/backend" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +func downloadFile(url string) (string, error) { + // Get the data + resp, err := http.Get(url) + if err != nil { + return "", err + } + defer resp.Body.Close() + + // Create the file + out, err := os.CreateTemp("", "image") + if err != nil { + return "", err + } + defer out.Close() + + // Write the body to file + _, err = io.Copy(out, resp.Body) + return out.Name(), err +} + +// https://platform.openai.com/docs/api-reference/images/create + +/* +* + + curl http://localhost:8080/v1/images/generations \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "A cute baby sea otter", + "n": 1, + "size": "512x512" + }' + +* +*/ +func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + m, input, err := readInput(c, o, false) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + if m == "" { + m = model.StableDiffusionBackend + } + log.Debug().Msgf("Loading model: %+v", m) + + config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, 0, 0, false) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + src := "" + if input.File != "" { + + fileData := []byte{} + // check if input.File is an URL, if so download it and save it + // to a temporary file + if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") { + out, err := downloadFile(input.File) + if err != nil { + return fmt.Errorf("failed downloading file:%w", err) + } + defer os.RemoveAll(out) + + fileData, err = os.ReadFile(out) + if err != nil { + return fmt.Errorf("failed reading file:%w", err) + } + + } else { + // base 64 decode the file and write it somewhere + // that we will cleanup + fileData, err = base64.StdEncoding.DecodeString(input.File) + if err != nil { + return err + } + } + + // Create a temporary file + outputFile, err := os.CreateTemp(o.ImageDir, "b64") + if err != nil { + return err + } + // write the base64 result + writer := bufio.NewWriter(outputFile) + _, err = writer.Write(fileData) + if err != nil { + outputFile.Close() + return err + } + outputFile.Close() + src = outputFile.Name() + defer os.RemoveAll(src) + } + + log.Debug().Msgf("Parameter Config: %+v", config) + + switch config.Backend { + case "stablediffusion": + config.Backend = model.StableDiffusionBackend + case "tinydream": + config.Backend = model.TinyDreamBackend + case "": + config.Backend = model.StableDiffusionBackend + } + + sizeParts := strings.Split(input.Size, "x") + if len(sizeParts) != 2 { + return fmt.Errorf("Invalid value for 'size'") + } + width, err := strconv.Atoi(sizeParts[0]) + if err != nil { + return fmt.Errorf("Invalid value for 'size'") + } + height, err := strconv.Atoi(sizeParts[1]) + if err != nil { + return fmt.Errorf("Invalid value for 'size'") + } + + b64JSON := false + if input.ResponseFormat.Type == "b64_json" { + b64JSON = true + } + // src and clip_skip + var result []schema.Item + for _, i := range config.PromptStrings { + n := input.N + if input.N == 0 { + n = 1 + } + for j := 0; j < n; j++ { + prompts := strings.Split(i, "|") + positive_prompt := prompts[0] + negative_prompt := "" + if len(prompts) > 1 { + negative_prompt = prompts[1] + } + + mode := 0 + step := config.Step + if step == 0 { + step = 15 + } + + if input.Mode != 0 { + mode = input.Mode + } + + if input.Step != 0 { + step = input.Step + } + + tempDir := "" + if !b64JSON { + tempDir = o.ImageDir + } + // Create a temporary file + outputFile, err := os.CreateTemp(tempDir, "b64") + if err != nil { + return err + } + outputFile.Close() + output := outputFile.Name() + ".png" + // Rename the temporary file + err = os.Rename(outputFile.Name(), output) + if err != nil { + return err + } + + baseURL := c.BaseURL() + + fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, src, output, o.Loader, *config, o) + if err != nil { + return err + } + if err := fn(); err != nil { + return err + } + + item := &schema.Item{} + + if b64JSON { + defer os.RemoveAll(output) + data, err := os.ReadFile(output) + if err != nil { + return err + } + item.B64JSON = base64.StdEncoding.EncodeToString(data) + } else { + base := filepath.Base(output) + item.URL = baseURL + "/generated-images/" + base + } + + result = append(result, *item) + } + } + + id := uuid.New().String() + created := int(time.Now().Unix()) + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Data: result, + } + + jsonResult, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", jsonResult) + + // Return the prediction in the response body + return c.JSON(resp) + } +} diff --git a/api/openai/inference.go b/api/openai/inference.go new file mode 100644 index 000000000000..816c960c3798 --- /dev/null +++ b/api/openai/inference.go @@ -0,0 +1,55 @@ +package openai + +import ( + "github.com/go-skynet/LocalAI/api/backend" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/api/schema" + model "github.com/go-skynet/LocalAI/pkg/model" +) + +func ComputeChoices( + req *schema.OpenAIRequest, + predInput string, + config *config.Config, + o *options.Option, + loader *model.ModelLoader, + cb func(string, *[]schema.Choice), + tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) { + n := req.N // number of completions to return + result := []schema.Choice{} + + if n == 0 { + n = 1 + } + + images := []string{} + for _, m := range req.Messages { + images = append(images, m.StringImages...) + } + + // get the model function to call for the result + predFunc, err := backend.ModelInference(req.Context, predInput, images, loader, *config, o, tokenCallback) + if err != nil { + return result, backend.TokenUsage{}, err + } + + tokenUsage := backend.TokenUsage{} + + for i := 0; i < n; i++ { + prediction, err := predFunc() + if err != nil { + return result, backend.TokenUsage{}, err + } + + tokenUsage.Prompt += prediction.Usage.Prompt + tokenUsage.Completion += prediction.Usage.Completion + + finetunedResponse := backend.Finetune(*config, predInput, prediction.Response) + cb(finetunedResponse, &result) + + //result = append(result, Choice{Text: prediction}) + + } + return result, tokenUsage, err +} diff --git a/core/http/endpoints/openai/list.go b/api/openai/list.go similarity index 70% rename from core/http/endpoints/openai/list.go rename to api/openai/list.go index 87ef9f33e7bb..8bc5bbe22bee 100644 --- a/core/http/endpoints/openai/list.go +++ b/api/openai/list.go @@ -3,21 +3,21 @@ package openai import ( "regexp" - "github.com/go-skynet/LocalAI/core/services" - "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/schema" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/schema" + model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" ) -func ListModelsEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader) func(ctx *fiber.Ctx) error { +func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func(ctx *fiber.Ctx) error { return func(c *fiber.Ctx) error { - models, err := ml.ListModels() + models, err := loader.ListModels() if err != nil { return err } var mm map[string]interface{} = map[string]interface{}{} - openAIModels := []schema.OpenAIModel{} + dataModels := []schema.OpenAIModel{} var filterFn func(name string) bool filter := c.Query("filter") @@ -40,13 +40,13 @@ func ListModelsEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader) func(c excludeConfigured := c.QueryBool("excludeConfigured", true) // Start with the known configurations - for _, c := range cl.GetAllConfigs() { + for _, c := range cm.GetAllConfigs() { if excludeConfigured { mm[c.Model] = nil } if filterFn(c.Name) { - openAIModels = append(openAIModels, schema.OpenAIModel{ID: c.Name, Object: "model"}) + dataModels = append(dataModels, schema.OpenAIModel{ID: c.Name, Object: "model"}) } } @@ -54,7 +54,7 @@ func ListModelsEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader) func(c for _, m := range models { // And only adds them if they shouldn't be skipped. if _, exists := mm[m]; !exists && filterFn(m) { - openAIModels = append(openAIModels, schema.OpenAIModel{ID: m, Object: "model"}) + dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"}) } } @@ -63,7 +63,7 @@ func ListModelsEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader) func(c Data []schema.OpenAIModel `json:"data"` }{ Object: "list", - Data: openAIModels, + Data: dataModels, }) } } diff --git a/api/openai/request.go b/api/openai/request.go new file mode 100644 index 000000000000..cc15fe409c27 --- /dev/null +++ b/api/openai/request.go @@ -0,0 +1,336 @@ +package openai + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "os" + "path/filepath" + "strings" + + config "github.com/go-skynet/LocalAI/api/config" + options "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/api/schema" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +func readInput(c *fiber.Ctx, o *options.Option, randomModel bool) (string, *schema.OpenAIRequest, error) { + loader := o.Loader + input := new(schema.OpenAIRequest) + ctx, cancel := context.WithCancel(o.Context) + input.Context = ctx + input.Cancel = cancel + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return "", nil, fmt.Errorf("failed parsing request body: %w", err) + } + + modelFile := input.Model + + if c.Params("model") != "" { + modelFile = c.Params("model") + } + + received, _ := json.Marshal(input) + + log.Debug().Msgf("Request received: %s", string(received)) + + // Set model from bearer token, if available + bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ") + bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) + + // If no model was specified, take the first available + if modelFile == "" && !bearerExists && randomModel { + models, _ := loader.ListModels() + if len(models) > 0 { + modelFile = models[0] + log.Debug().Msgf("No model specified, using: %s", modelFile) + } else { + log.Debug().Msgf("No model specified, returning error") + return "", nil, fmt.Errorf("no model specified") + } + } + + // If a model is found in bearer token takes precedence + if bearerExists { + log.Debug().Msgf("Using model from bearer token: %s", bearer) + modelFile = bearer + } + return modelFile, input, nil +} + +// this function check if the string is an URL, if it's an URL downloads the image in memory +// encodes it in base64 and returns the base64 string +func getBase64Image(s string) (string, error) { + if strings.HasPrefix(s, "http") { + // download the image + resp, err := http.Get(s) + if err != nil { + return "", err + } + defer resp.Body.Close() + + // read the image data into memory + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + return "", err + } + + // encode the image data in base64 + encoded := base64.StdEncoding.EncodeToString(data) + + // return the base64 string + return encoded, nil + } + + // if the string instead is prefixed with "data:image/jpeg;base64,", drop it + if strings.HasPrefix(s, "data:image/jpeg;base64,") { + return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil + } + return "", fmt.Errorf("not valid string") +} + +func updateConfig(config *config.Config, input *schema.OpenAIRequest) { + if input.Echo { + config.Echo = input.Echo + } + if input.TopK != 0 { + config.TopK = input.TopK + } + if input.TopP != 0 { + config.TopP = input.TopP + } + + if input.Backend != "" { + config.Backend = input.Backend + } + + if input.ClipSkip != 0 { + config.Diffusers.ClipSkip = input.ClipSkip + } + + if input.ModelBaseName != "" { + config.AutoGPTQ.ModelBaseName = input.ModelBaseName + } + + if input.NegativePromptScale != 0 { + config.NegativePromptScale = input.NegativePromptScale + } + + if input.UseFastTokenizer { + config.UseFastTokenizer = input.UseFastTokenizer + } + + if input.NegativePrompt != "" { + config.NegativePrompt = input.NegativePrompt + } + + if input.RopeFreqBase != 0 { + config.RopeFreqBase = input.RopeFreqBase + } + + if input.RopeFreqScale != 0 { + config.RopeFreqScale = input.RopeFreqScale + } + + if input.Grammar != "" { + config.Grammar = input.Grammar + } + + if input.Temperature != 0 { + config.Temperature = input.Temperature + } + + if input.Maxtokens != 0 { + config.Maxtokens = input.Maxtokens + } + + switch stop := input.Stop.(type) { + case string: + if stop != "" { + config.StopWords = append(config.StopWords, stop) + } + case []interface{}: + for _, pp := range stop { + if s, ok := pp.(string); ok { + config.StopWords = append(config.StopWords, s) + } + } + } + + // Decode each request's message content + index := 0 + for i, m := range input.Messages { + switch content := m.Content.(type) { + case string: + input.Messages[i].StringContent = content + case []interface{}: + dat, _ := json.Marshal(content) + c := []schema.Content{} + json.Unmarshal(dat, &c) + for _, pp := range c { + if pp.Type == "text" { + input.Messages[i].StringContent = pp.Text + } else if pp.Type == "image_url" { + // Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64: + base64, err := getBase64Image(pp.ImageURL.URL) + if err == nil { + input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff + // set a placeholder for each image + input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent + index++ + } else { + fmt.Print("Failed encoding image", err) + } + } + } + } + } + + if input.RepeatPenalty != 0 { + config.RepeatPenalty = input.RepeatPenalty + } + + if input.Keep != 0 { + config.Keep = input.Keep + } + + if input.Batch != 0 { + config.Batch = input.Batch + } + + if input.F16 { + config.F16 = input.F16 + } + + if input.IgnoreEOS { + config.IgnoreEOS = input.IgnoreEOS + } + + if input.Seed != 0 { + config.Seed = input.Seed + } + + if input.Mirostat != 0 { + config.LLMConfig.Mirostat = input.Mirostat + } + + if input.MirostatETA != 0 { + config.LLMConfig.MirostatETA = input.MirostatETA + } + + if input.MirostatTAU != 0 { + config.LLMConfig.MirostatTAU = input.MirostatTAU + } + + if input.TypicalP != 0 { + config.TypicalP = input.TypicalP + } + + switch inputs := input.Input.(type) { + case string: + if inputs != "" { + config.InputStrings = append(config.InputStrings, inputs) + } + case []interface{}: + for _, pp := range inputs { + switch i := pp.(type) { + case string: + config.InputStrings = append(config.InputStrings, i) + case []interface{}: + tokens := []int{} + for _, ii := range i { + tokens = append(tokens, int(ii.(float64))) + } + config.InputToken = append(config.InputToken, tokens) + } + } + } + + // Can be either a string or an object + switch fnc := input.FunctionCall.(type) { + case string: + if fnc != "" { + config.SetFunctionCallString(fnc) + } + case map[string]interface{}: + var name string + n, exists := fnc["name"] + if exists { + nn, e := n.(string) + if e { + name = nn + } + } + config.SetFunctionCallNameString(name) + } + + switch p := input.Prompt.(type) { + case string: + config.PromptStrings = append(config.PromptStrings, p) + case []interface{}: + for _, pp := range p { + if s, ok := pp.(string); ok { + config.PromptStrings = append(config.PromptStrings, s) + } + } + } +} + +func readConfig(modelFile string, input *schema.OpenAIRequest, cm *config.ConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.Config, *schema.OpenAIRequest, error) { + // Load a config file if present after the model name + modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml") + + var cfg *config.Config + + defaults := func() { + cfg = config.DefaultConfig(modelFile) + cfg.ContextSize = ctx + cfg.Threads = threads + cfg.F16 = f16 + cfg.Debug = debug + } + + cfgExisting, exists := cm.GetConfig(modelFile) + if !exists { + if _, err := os.Stat(modelConfig); err == nil { + if err := cm.LoadConfig(modelConfig); err != nil { + return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) + } + cfgExisting, exists = cm.GetConfig(modelFile) + if exists { + cfg = &cfgExisting + } else { + defaults() + } + } else { + defaults() + } + } else { + cfg = &cfgExisting + } + + // Set the parameters for the language model prediction + updateConfig(cfg, input) + + // Don't allow 0 as setting + if cfg.Threads == 0 { + if threads != 0 { + cfg.Threads = threads + } else { + cfg.Threads = 4 + } + } + + // Enforce debug flag if passed from CLI + if debug { + cfg.Debug = true + } + + return cfg, input, nil +} diff --git a/api/openai/transcription.go b/api/openai/transcription.go new file mode 100644 index 000000000000..895c110f5df4 --- /dev/null +++ b/api/openai/transcription.go @@ -0,0 +1,71 @@ +package openai + +import ( + "fmt" + "io" + "net/http" + "os" + "path" + "path/filepath" + + "github.com/go-skynet/LocalAI/api/backend" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +// https://platform.openai.com/docs/api-reference/audio/create +func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + m, input, err := readInput(c, o, false) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + // retrieve the file data from the request + file, err := c.FormFile("file") + if err != nil { + return err + } + f, err := file.Open() + if err != nil { + return err + } + defer f.Close() + + dir, err := os.MkdirTemp("", "whisper") + + if err != nil { + return err + } + defer os.RemoveAll(dir) + + dst := filepath.Join(dir, path.Base(file.Filename)) + dstFile, err := os.Create(dst) + if err != nil { + return err + } + + if _, err := io.Copy(dstFile, f); err != nil { + log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err) + return err + } + + log.Debug().Msgf("Audio file copied to: %+v", dst) + + tr, err := backend.ModelTranscription(dst, input.Language, o.Loader, *config, o) + if err != nil { + return err + } + + log.Debug().Msgf("Trascribed: %+v", tr) + // TODO: handle different outputs here + return c.Status(http.StatusOK).JSON(tr) + } +} diff --git a/pkg/schema/startup_options.go b/api/options/options.go similarity index 68% rename from pkg/schema/startup_options.go rename to api/options/options.go index be6d2e9a6400..e83eaaad8c19 100644 --- a/pkg/schema/startup_options.go +++ b/api/options/options.go @@ -1,4 +1,4 @@ -package schema +package options import ( "context" @@ -6,14 +6,16 @@ import ( "encoding/json" "time" + "github.com/go-skynet/LocalAI/metrics" "github.com/go-skynet/LocalAI/pkg/gallery" + model "github.com/go-skynet/LocalAI/pkg/model" "github.com/rs/zerolog/log" ) -type StartupOptions struct { +type Option struct { Context context.Context ConfigFile string - ModelPath string + Loader *model.ModelLoader UploadLimitMB, Threads, ContextSize int F16 bool Debug, DisableMessage bool @@ -24,7 +26,7 @@ type StartupOptions struct { PreloadModelsFromPath string CORSAllowOrigins string ApiKeys []string - Metrics *LocalAIMetrics + Metrics *metrics.Metrics Galleries []gallery.Gallery @@ -45,14 +47,12 @@ type StartupOptions struct { ModelsURL []string WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration - - LocalAIConfigDir string } -type AppOption func(*StartupOptions) +type AppOption func(*Option) -func NewStartupOptions(o ...AppOption) *StartupOptions { - opt := &StartupOptions{ +func NewOptions(o ...AppOption) *Option { + opt := &Option{ Context: context.Background(), UploadLimitMB: 15, Threads: 1, @@ -67,57 +67,57 @@ func NewStartupOptions(o ...AppOption) *StartupOptions { } func WithModelsURL(urls ...string) AppOption { - return func(o *StartupOptions) { + return func(o *Option) { o.ModelsURL = urls } } func WithCors(b bool) AppOption { - return func(o *StartupOptions) { + return func(o *Option) { o.CORS = b } } -var EnableWatchDog = func(o *StartupOptions) { +var EnableWatchDog = func(o *Option) { o.WatchDog = true } -var EnableWatchDogIdleCheck = func(o *StartupOptions) { +var EnableWatchDogIdleCheck = func(o *Option) { o.WatchDog = true o.WatchDogIdle = true } -var EnableWatchDogBusyCheck = func(o *StartupOptions) { +var EnableWatchDogBusyCheck = func(o *Option) { o.WatchDog = true o.WatchDogBusy = true } func SetWatchDogBusyTimeout(t time.Duration) AppOption { - return func(o *StartupOptions) { + return func(o *Option) { o.WatchDogBusyTimeout = t } } func SetWatchDogIdleTimeout(t time.Duration) AppOption { - return func(o *StartupOptions) { + return func(o *Option) { o.WatchDogIdleTimeout = t } } -var EnableSingleBackend = func(o *StartupOptions) { +var EnableSingleBackend = func(o *Option) { o.SingleBackend = true } -var EnableParallelBackendRequests = func(o *StartupOptions) { +var EnableParallelBackendRequests = func(o *Option) { o.ParallelBackendRequests = true } -var EnableGalleriesAutoload = func(o *StartupOptions) { +var EnableGalleriesAutoload = func(o *Option) { o.AutoloadGalleries = true } func WithExternalBackend(name string, uri string) AppOption { - return func(o *StartupOptions) { + return func(o *Option) { if o.ExternalGRPCBackends == nil { o.ExternalGRPCBackends = make(map[string]string) } @@ -126,25 +126,25 @@ func WithExternalBackend(name string, uri string) AppOption { } func WithCorsAllowOrigins(b string) AppOption { - return func(o *StartupOptions) { + return func(o *Option) { o.CORSAllowOrigins = b } } func WithBackendAssetsOutput(out string) AppOption { - return func(o *StartupOptions) { + return func(o *Option) { o.AssetsDestination = out } } func WithBackendAssets(f embed.FS) AppOption { - return func(o *StartupOptions) { + return func(o *Option) { o.BackendAssets = f } } func WithStringGalleries(galls string) AppOption { - return func(o *StartupOptions) { + return func(o *Option) { if galls == "" { log.Debug().Msgf("no galleries to load") o.Galleries = []gallery.Gallery{} @@ -159,102 +159,96 @@ func WithStringGalleries(galls string) AppOption { } func WithGalleries(galleries []gallery.Gallery) AppOption { - return func(o *StartupOptions) { + return func(o *Option) { o.Galleries = append(o.Galleries, galleries...) } } func WithContext(ctx context.Context) AppOption { - return func(o *StartupOptions) { + return func(o *Option) { o.Context = ctx } } func WithYAMLConfigPreload(configFile string) AppOption { - return func(o *StartupOptions) { + return func(o *Option) { o.PreloadModelsFromPath = configFile } } func WithJSONStringPreload(configFile string) AppOption { - return func(o *StartupOptions) { + return func(o *Option) { o.PreloadJSONModels = configFile } } func WithConfigFile(configFile string) AppOption { - return func(o *StartupOptions) { + return func(o *Option) { o.ConfigFile = configFile } } -func WithModelPath(path string) AppOption { - return func(o *StartupOptions) { - o.ModelPath = path +func WithModelLoader(loader *model.ModelLoader) AppOption { + return func(o *Option) { + o.Loader = loader } } func WithUploadLimitMB(limit int) AppOption { - return func(o *StartupOptions) { + return func(o *Option) { o.UploadLimitMB = limit } } func WithThreads(threads int) AppOption { - return func(o *StartupOptions) { + return func(o *Option) { o.Threads = threads } } func WithContextSize(ctxSize int) AppOption { - return func(o *StartupOptions) { + return func(o *Option) { o.ContextSize = ctxSize } } func WithF16(f16 bool) AppOption { - return func(o *StartupOptions) { + return func(o *Option) { o.F16 = f16 } } func WithDebug(debug bool) AppOption { - return func(o *StartupOptions) { + return func(o *Option) { o.Debug = debug } } func WithDisableMessage(disableMessage bool) AppOption { - return func(o *StartupOptions) { + return func(o *Option) { o.DisableMessage = disableMessage } } func WithAudioDir(audioDir string) AppOption { - return func(o *StartupOptions) { + return func(o *Option) { o.AudioDir = audioDir } } func WithImageDir(imageDir string) AppOption { - return func(o *StartupOptions) { + return func(o *Option) { o.ImageDir = imageDir } } func WithApiKeys(apiKeys []string) AppOption { - return func(o *StartupOptions) { + return func(o *Option) { o.ApiKeys = apiKeys } } -func WithMetrics(metrics *LocalAIMetrics) AppOption { - return func(o *StartupOptions) { - o.Metrics = metrics - } -} - -func WithLocalAIConfigDir(configDir string) AppOption { - return func(o *StartupOptions) { - o.LocalAIConfigDir = configDir +func WithMetrics(meter *metrics.Metrics) AppOption { + return func(o *Option) { + o.Metrics = meter } } diff --git a/pkg/schema/openai.go b/api/schema/openai.go similarity index 97% rename from pkg/schema/openai.go rename to api/schema/openai.go index 32ecbd8bef66..6355ff63d5e2 100644 --- a/pkg/schema/openai.go +++ b/api/schema/openai.go @@ -3,6 +3,8 @@ package schema import ( "context" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/pkg/grammar" ) @@ -88,7 +90,7 @@ type ChatCompletionResponseFormat struct { } type OpenAIRequest struct { - PredictionOptions + config.PredictionOptions Context context.Context Cancel context.CancelFunc diff --git a/pkg/schema/whisper.go b/api/schema/whisper.go similarity index 60% rename from pkg/schema/whisper.go rename to api/schema/whisper.go index 7225980f577b..41413c1f06ed 100644 --- a/pkg/schema/whisper.go +++ b/api/schema/whisper.go @@ -2,7 +2,7 @@ package schema import "time" -type WhisperSegment struct { +type Segment struct { Id int `json:"id"` Start time.Duration `json:"start"` End time.Duration `json:"end"` @@ -10,7 +10,7 @@ type WhisperSegment struct { Tokens []int `json:"tokens"` } -type WhisperResult struct { - Segments []WhisperSegment `json:"segments"` - Text string `json:"text"` +type Result struct { + Segments []Segment `json:"segments"` + Text string `json:"text"` } diff --git a/backend/go/transcribe/transcript.go b/backend/go/transcribe/transcript.go index 6d1ba3da608c..ebd43eca6b84 100644 --- a/backend/go/transcribe/transcript.go +++ b/backend/go/transcribe/transcript.go @@ -8,7 +8,7 @@ import ( "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" "github.com/go-audio/wav" - "github.com/go-skynet/LocalAI/pkg/schema" + "github.com/go-skynet/LocalAI/api/schema" ) func sh(c string) (string, error) { @@ -29,8 +29,8 @@ func audioToWav(src, dst string) error { return nil } -func Transcript(model whisper.Model, audiopath, language string, threads uint) (schema.WhisperResult, error) { - res := schema.WhisperResult{} +func Transcript(model whisper.Model, audiopath, language string, threads uint) (schema.Result, error) { + res := schema.Result{} dir, err := os.MkdirTemp("", "whisper") if err != nil { @@ -90,7 +90,7 @@ func Transcript(model whisper.Model, audiopath, language string, threads uint) ( tokens = append(tokens, t.Id) } - segment := schema.WhisperSegment{Id: s.Num, Text: s.Text, Start: s.Start, End: s.End, Tokens: tokens} + segment := schema.Segment{Id: s.Num, Text: s.Text, Start: s.Start, End: s.End, Tokens: tokens} res.Segments = append(res.Segments, segment) res.Text += s.Text diff --git a/backend/go/transcribe/whisper.go b/backend/go/transcribe/whisper.go index 6336589eb220..a033afb0cdbc 100644 --- a/backend/go/transcribe/whisper.go +++ b/backend/go/transcribe/whisper.go @@ -4,9 +4,9 @@ package main // It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) import ( "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" + "github.com/go-skynet/LocalAI/api/schema" "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/go-skynet/LocalAI/pkg/schema" ) type Whisper struct { @@ -21,6 +21,6 @@ func (sd *Whisper) Load(opts *pb.ModelOptions) error { return err } -func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.WhisperResult, error) { +func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.Result, error) { return Transcript(sd.whisper, opts.Dst, opts.Language, uint(opts.Threads)) } diff --git a/config/.keep b/config/.keep deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/core/backend/image.go b/core/backend/image.go deleted file mode 100644 index caaa8f38d9be..000000000000 --- a/core/backend/image.go +++ /dev/null @@ -1,210 +0,0 @@ -package backend - -import ( - "encoding/base64" - "fmt" - "os" - "path" - "path/filepath" - "strconv" - "strings" - "time" - - "github.com/go-skynet/LocalAI/core/services" - "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/schema" - "github.com/go-skynet/LocalAI/pkg/utils" - "github.com/google/uuid" - "github.com/rs/zerolog/log" -) - -func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, c schema.Config, o *schema.StartupOptions) (func() error, error) { - - opts := modelOpts(c, o, []model.Option{ - model.WithBackendString(c.Backend), - model.WithAssetDir(o.AssetsDestination), - model.WithThreads(uint32(c.Threads)), - model.WithContext(o.Context), - model.WithModel(c.Model), - model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{ - CUDA: c.CUDA || c.Diffusers.CUDA, - SchedulerType: c.Diffusers.SchedulerType, - PipelineType: c.Diffusers.PipelineType, - CFGScale: c.Diffusers.CFGScale, - LoraAdapter: c.LoraAdapter, - LoraScale: c.LoraScale, - LoraBase: c.LoraBase, - IMG2IMG: c.Diffusers.IMG2IMG, - CLIPModel: c.Diffusers.ClipModel, - CLIPSubfolder: c.Diffusers.ClipSubFolder, - CLIPSkip: int32(c.Diffusers.ClipSkip), - ControlNet: c.Diffusers.ControlNet, - }), - model.WithExternalBackends(o.ExternalGRPCBackends, false), - }) - - inferenceModel, err := loader.BackendLoader( - opts..., - ) - if err != nil { - return nil, err - } - - fn := func() error { - _, err := inferenceModel.GenerateImage( - o.Context, - &proto.GenerateImageRequest{ - Height: int32(height), - Width: int32(width), - Mode: int32(mode), - Step: int32(step), - Seed: int32(seed), - CLIPSkip: int32(c.Diffusers.ClipSkip), - PositivePrompt: positive_prompt, - NegativePrompt: negative_prompt, - Dst: dst, - Src: src, - EnableParameters: c.Diffusers.EnableParameters, - }) - return err - } - - return fn, nil -} - -func ImageGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.OpenAIResponse, error) { - id := uuid.New().String() - created := int(time.Now().Unix()) - - if modelName == "" { - modelName = model.StableDiffusionBackend - } - log.Debug().Msgf("Loading model: %+v", modelName) - - config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions) - if err != nil { - return nil, fmt.Errorf("failed reading parameters from request: %w", err) - } - - src := "" - if input.File != "" { - if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") { - src, err = utils.CreateTempFileFromUrl(input.File, "", "image-src") - if err != nil { - return nil, fmt.Errorf("failed downloading file:%w", err) - } - } else { - src, err = utils.CreateTempFileFromBase64(input.File, "", "base64-image-src") - if err != nil { - return nil, fmt.Errorf("error creating temporary image source file: %w", err) - } - } - } - - log.Debug().Msgf("Parameter Config: %+v", config) - - switch config.Backend { - case "stablediffusion": - config.Backend = model.StableDiffusionBackend - case "tinydream": - config.Backend = model.TinyDreamBackend - case "": - config.Backend = model.StableDiffusionBackend - } - - sizeParts := strings.Split(input.Size, "x") - if len(sizeParts) != 2 { - return nil, fmt.Errorf("invalid value for 'size'") - } - width, err := strconv.Atoi(sizeParts[0]) - if err != nil { - return nil, fmt.Errorf("invalid value for 'size'") - } - height, err := strconv.Atoi(sizeParts[1]) - if err != nil { - return nil, fmt.Errorf("invalid value for 'size'") - } - - b64JSON := false - if input.ResponseFormat.Type == "b64_json" { - b64JSON = true - } - // src and clip_skip - var result []schema.Item - for _, i := range config.PromptStrings { - n := input.N - if input.N == 0 { - n = 1 - } - for j := 0; j < n; j++ { - prompts := strings.Split(i, "|") - positive_prompt := prompts[0] - negative_prompt := "" - if len(prompts) > 1 { - negative_prompt = prompts[1] - } - - mode := 0 - step := config.Step - if step == 0 { - step = 15 - } - - if input.Mode != 0 { - mode = input.Mode - } - - if input.Step != 0 { - step = input.Step - } - - tempDir := "" - if !b64JSON { - tempDir = startupOptions.ImageDir - } - // Create a temporary file - outputFile, err := os.CreateTemp(tempDir, "b64") - if err != nil { - return nil, err - } - outputFile.Close() - output := outputFile.Name() + ".png" - // Rename the temporary file - err = os.Rename(outputFile.Name(), output) - if err != nil { - return nil, err - } - - fn, err := ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, src, output, ml, *config, startupOptions) - if err != nil { - return nil, err - } - if err := fn(); err != nil { - return nil, err - } - - item := &schema.Item{} - - if b64JSON { - defer os.RemoveAll(output) - data, err := os.ReadFile(output) - if err != nil { - return nil, err - } - item.B64JSON = base64.StdEncoding.EncodeToString(data) - } else { - base := filepath.Base(output) - item.URL = path.Join(startupOptions.ImageDir, base) - } - - result = append(result, *item) - } - } - - return &schema.OpenAIResponse{ - ID: id, - Created: created, - Data: result, - }, nil -} diff --git a/core/backend/llm.go b/core/backend/llm.go deleted file mode 100644 index bf2db46b5052..000000000000 --- a/core/backend/llm.go +++ /dev/null @@ -1,861 +0,0 @@ -package backend - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "os" - "path/filepath" - "regexp" - "strings" - "sync" - "time" - "unicode/utf8" - - "github.com/go-skynet/LocalAI/core/services" - "github.com/go-skynet/LocalAI/pkg/gallery" - "github.com/go-skynet/LocalAI/pkg/grammar" - "github.com/go-skynet/LocalAI/pkg/grpc" - "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/schema" - "github.com/go-skynet/LocalAI/pkg/utils" - "github.com/google/uuid" - "github.com/rs/zerolog/log" -) - -////////// TYPES ////////////// - -type LLMResponse struct { - Response string // should this be []byte? - Usage TokenUsage -} - -// TODO: Test removing this and using the variant in pkg/schema someday? -type TokenUsage struct { - Prompt int - Completion int -} - -type TemplateConfigBindingFn func(*schema.Config) *string - -// type LLMStreamProcessor func(s string, req *schema.OpenAIRequest, config *schema.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) - -/////// CONSTS /////////// - -const DEFAULT_NO_ACTION_NAME = "answer" -const DEFAULT_NO_ACTION_DESCRIPTION = "use this action to answer without performing any action" - -////// INFERENCE ///////// - -func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c schema.Config, o *schema.StartupOptions, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) { - modelFile := c.Model - - grpcOpts := gRPCModelOpts(c) - - var inferenceModel *grpc.Client - var err error - - opts := modelOpts(c, o, []model.Option{ - model.WithLoadGRPCLoadModelOpts(grpcOpts), - model.WithThreads(uint32(c.Threads)), // some models uses this to allocate threads during startup - model.WithAssetDir(o.AssetsDestination), - model.WithModel(modelFile), - model.WithContext(o.Context), - model.WithExternalBackends(o.ExternalGRPCBackends, false), - }) - - if c.Backend != "" { - opts = append(opts, model.WithBackendString(c.Backend)) - } - - // Check if the modelFile exists, if it doesn't try to load it from the gallery - if o.AutoloadGalleries { // experimental - if _, err := os.Stat(modelFile); os.IsNotExist(err) { - utils.ResetDownloadTimers() - // if we failed to load the model, we try to download it - err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction) - if err != nil { - return nil, err - } - } - } - - if c.Backend == "" { - inferenceModel, err = loader.GreedyLoader(opts...) - } else { - inferenceModel, err = loader.BackendLoader(opts...) - } - - if err != nil { - return nil, err - } - - // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported - fn := func() (LLMResponse, error) { - opts := gRPCPredictOpts(c, loader.ModelPath) - opts.Prompt = s - opts.Images = images - - tokenUsage := TokenUsage{} - - // check the per-model feature flag for usage, since tokenCallback may have a cost. - // Defaults to off as for now it is still experimental - if c.FeatureFlag.Enabled("usage") { - userTokenCallback := tokenCallback - if userTokenCallback == nil { - userTokenCallback = func(token string, usage TokenUsage) bool { - return true - } - } - - promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts) - if pErr == nil && promptInfo.Length > 0 { - tokenUsage.Prompt = int(promptInfo.Length) - } - - tokenCallback = func(token string, usage TokenUsage) bool { - tokenUsage.Completion++ - return userTokenCallback(token, tokenUsage) - } - } - - if tokenCallback != nil { - ss := "" - - var partialRune []byte - err := inferenceModel.PredictStream(ctx, opts, func(chars []byte) { - partialRune = append(partialRune, chars...) - - for len(partialRune) > 0 { - r, size := utf8.DecodeRune(partialRune) - if r == utf8.RuneError { - // incomplete rune, wait for more bytes - break - } - - tokenCallback(string(r), tokenUsage) - ss += string(r) - - partialRune = partialRune[size:] - } - }) - return LLMResponse{ - Response: ss, - Usage: tokenUsage, - }, err - } else { - // TODO: Is the chicken bit the only way to get here? is that acceptable? - reply, err := inferenceModel.Predict(ctx, opts) - if err != nil { - return LLMResponse{}, err - } - return LLMResponse{ - Response: string(reply.Message), - Usage: tokenUsage, - }, err - } - } - - return fn, nil -} - -var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) -var mu sync.Mutex = sync.Mutex{} - -func Finetune(config schema.Config, input, prediction string) string { - if config.Echo { - prediction = input + prediction - } - - for _, c := range config.Cutstrings { - mu.Lock() - reg, ok := cutstrings[c] - if !ok { - cutstrings[c] = regexp.MustCompile(c) - reg = cutstrings[c] - } - mu.Unlock() - prediction = reg.ReplaceAllString(prediction, "") - } - - for _, c := range config.TrimSpace { - prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c)) - } - - for _, c := range config.TrimSuffix { - prediction = strings.TrimSpace(strings.TrimSuffix(prediction, c)) - } - return prediction - -} - -////// CONFIG AND REQUEST HANDLING /////////////// - -func ReadConfigFromFileAndCombineWithOpenAIRequest(modelFile string, input *schema.OpenAIRequest, cm *services.ConfigLoader, startupOptions *schema.StartupOptions) (*schema.Config, *schema.OpenAIRequest, error) { - // Load a config file if present after the model name - modelConfig := filepath.Join(startupOptions.ModelPath, modelFile+".yaml") - - var cfg *schema.Config - - defaults := func() { - cfg = schema.DefaultConfig(modelFile) - cfg.ContextSize = startupOptions.ContextSize - cfg.Threads = startupOptions.Threads - cfg.F16 = startupOptions.F16 - cfg.Debug = startupOptions.Debug - } - - cfgExisting, exists := cm.GetConfig(modelFile) - if !exists { - if _, err := os.Stat(modelConfig); err == nil { - if err := cm.LoadConfig(modelConfig); err != nil { - return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) - } - cfgExisting, exists = cm.GetConfig(modelFile) - if exists { - cfg = &cfgExisting - } else { - defaults() - } - } else { - defaults() - } - } else { - cfg = &cfgExisting - } - - // Set the parameters for the language model prediction - schema.UpdateConfigFromOpenAIRequest(cfg, input) - - // Don't allow 0 as setting - if cfg.Threads == 0 { - if startupOptions.Threads != 0 { - cfg.Threads = startupOptions.Threads - } else { - cfg.Threads = 4 - } - } - - // Enforce debug flag if passed from CLI - if startupOptions.Debug { - cfg.Debug = true - } - - return cfg, input, nil -} - -func ComputeChoices( - req *schema.OpenAIRequest, - predInput string, - config *schema.Config, - o *schema.StartupOptions, - loader *model.ModelLoader, - cb func(string, *[]schema.Choice), - tokenCallback func(string, TokenUsage) bool) ([]schema.Choice, TokenUsage, error) { - n := req.N // number of completions to return - result := []schema.Choice{} - - if n == 0 { - n = 1 - } - - images := []string{} - for _, m := range req.Messages { - images = append(images, m.StringImages...) - } - - // get the model function to call for the result - predFunc, err := ModelInference(req.Context, predInput, images, loader, *config, o, tokenCallback) - if err != nil { - return result, TokenUsage{}, err - } - - tokenUsage := TokenUsage{} - - for i := 0; i < n; i++ { - prediction, err := predFunc() - if err != nil { - return result, TokenUsage{}, err - } - - tokenUsage.Prompt += prediction.Usage.Prompt - tokenUsage.Completion += prediction.Usage.Completion - - finetunedResponse := Finetune(*config, predInput, prediction.Response) - cb(finetunedResponse, &result) - - //result = append(result, Choice{Text: prediction}) - - } - return result, tokenUsage, err -} - -// TODO: No functions???? Commonize with prepareChatGenerationOpenAIRequest below? -func prepareGenerationOpenAIRequest(bindingFn TemplateConfigBindingFn, modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.Config, error) { - config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions) - if err != nil { - return nil, fmt.Errorf("failed reading parameters from request:%w", err) - } - - if input.ResponseFormat.Type == "json_object" { - input.Grammar = grammar.JSONBNF - } - - log.Debug().Msgf("Parameter Config: %+v", config) - - configTemplate := bindingFn(config) - - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - if (*configTemplate == "") && (ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model))) { - *configTemplate = config.Model - } - if *configTemplate == "" { - return nil, fmt.Errorf(("failed to find templateConfig")) - } - - return config, nil -} - -////////// SPECIFIC REQUESTS ////////////// -// TODO: For round one of the refactor, give each of the three primary text endpoints their own function? -// SEMITODO: During a merge, edit/completion were semi-combined - but remain nominally split -// Can cleanup into a common form later if possible easier if they are all here for now -// If they remain different, extract each of these named segments to a seperate file - -func prepareChatGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.Config, string, bool, error) { - - // IMPORTANT DEFS - funcs := grammar.Functions{} - - // The Basic Begining - - config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions) - if err != nil { - return nil, "", false, fmt.Errorf("failed reading parameters from request:%w", err) - } - log.Debug().Msgf("Configuration read: %+v", config) - - // Special Input/Config Handling - - // Allow the user to set custom actions via config file - // to be "embedded" in each model - but if they are missing, use defaults. - if config.FunctionsConfig.NoActionFunctionName == "" { - config.FunctionsConfig.NoActionFunctionName = DEFAULT_NO_ACTION_NAME - } - if config.FunctionsConfig.NoActionDescriptionName == "" { - config.FunctionsConfig.NoActionDescriptionName = DEFAULT_NO_ACTION_DESCRIPTION - } - - if input.ResponseFormat.Type == "json_object" { - input.Grammar = grammar.JSONBNF - } - - processFunctions := len(input.Functions) > 0 && config.ShouldUseFunctions() - - if processFunctions { - log.Debug().Msgf("Response needs to process functions") - - noActionGrammar := grammar.Function{ - Name: config.FunctionsConfig.NoActionFunctionName, - Description: config.FunctionsConfig.NoActionDescriptionName, - Parameters: map[string]interface{}{ - "properties": map[string]interface{}{ - "message": map[string]interface{}{ - "type": "string", - "description": "The message to reply the user with", - }}, - }, - } - - // Append the no action function - funcs = append(funcs, input.Functions...) - if !config.FunctionsConfig.DisableNoAction { - funcs = append(funcs, noActionGrammar) - } - - // Force picking one of the functions by the request - if config.FunctionToCall() != "" { - funcs = funcs.Select(config.FunctionToCall()) - } - - // Update input grammar - jsStruct := funcs.ToJSONStructure() - config.Grammar = jsStruct.Grammar("") - } else if input.JSONFunctionGrammarObject != nil { - config.Grammar = input.JSONFunctionGrammarObject.Grammar("") - } - - log.Debug().Msgf("Parameters: %+v", config) - - var predInput string - - suppressConfigSystemPrompt := false - mess := []string{} - for messageIndex, i := range input.Messages { - var content string - role := i.Role - - // if function call, we might want to customize the role so we can display better that the "assistant called a json action" - // if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request - if i.FunctionCall != nil && i.Role == "assistant" { - roleFn := "assistant_function_call" - r := config.Roles[roleFn] - if r != "" { - role = roleFn - } - } - r := config.Roles[role] - contentExists := i.Content != nil && i.StringContent != "" - // First attempt to populate content via a chat message specific template - if config.TemplateConfig.ChatMessage != "" { - chatMessageData := model.ChatMessageTemplateData{ - SystemPrompt: config.SystemPrompt, - Role: r, - RoleName: role, - Content: i.StringContent, - MessageIndex: messageIndex, - } - templatedChatMessage, err := ml.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData) - if err != nil { - log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err) - } else { - if templatedChatMessage == "" { - log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData) - continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf - } - log.Debug().Msgf("templated message for chat: %s", templatedChatMessage) - content = templatedChatMessage - } - } - // If this model doesn't have such a template, or if that template fails to return a value, template at the message level. - if content == "" { - if r != "" { - if contentExists { - content = fmt.Sprint(r, i.StringContent) - } - if i.FunctionCall != nil { - j, err := json.Marshal(i.FunctionCall) - if err == nil { - if contentExists { - content += "\n" + fmt.Sprint(r, " ", string(j)) - } else { - content = fmt.Sprint(r, " ", string(j)) - } - } - } - } else { - if contentExists { - content = fmt.Sprint(i.StringContent) - } - if i.FunctionCall != nil { - j, err := json.Marshal(i.FunctionCall) - if err == nil { - if contentExists { - content += "\n" + string(j) - } else { - content = string(j) - } - } - } - } - // Special Handling: System. We care if it was printed at all, not the r branch, so check seperately - if contentExists && role == "system" { - suppressConfigSystemPrompt = true - } - } - - mess = append(mess, content) - } - - predInput = strings.Join(mess, "\n") - log.Debug().Msgf("Prompt (before templating): %s", predInput) - - templateFile := "" - - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { - templateFile = config.Model - } - - if config.TemplateConfig.Chat != "" && !processFunctions { - templateFile = config.TemplateConfig.Chat - } - - if config.TemplateConfig.Functions != "" && processFunctions { - templateFile = config.TemplateConfig.Functions - } - - if templateFile != "" { - templatedInput, err := ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{ - SystemPrompt: config.SystemPrompt, - SuppressSystemPrompt: suppressConfigSystemPrompt, - Input: predInput, - Functions: funcs, - }) - if err == nil { - predInput = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", predInput) - } else { - log.Debug().Msgf("Template failed loading: %s", err.Error()) - } - } - - log.Debug().Msgf("Prompt (after templating): %s", predInput) - if processFunctions { - log.Debug().Msgf("Grammar: %+v", config.Grammar) - } - - return config, predInput, processFunctions, nil - -} - -func EditGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.OpenAIResponse, error) { - id := uuid.New().String() - created := int(time.Now().Unix()) - - binding := func(config *schema.Config) *string { - return &config.TemplateConfig.Edit - } - - config, err := prepareGenerationOpenAIRequest(binding, modelName, input, cl, ml, startupOptions) - if err != nil { - return nil, err - } - - var result []schema.Choice - totalTokenUsage := TokenUsage{} - - for _, i := range config.InputStrings { - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := ml.EvaluateTemplateForPrompt(model.EditPromptTemplate, config.TemplateConfig.Edit, model.PromptTemplateData{ - Input: i, - Instruction: input.Instruction, - SystemPrompt: config.SystemPrompt, - }) - if err == nil { - i = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", i) - } - - r, tokenUsage, err := ComputeChoices(input, i, config, startupOptions, ml, func(s string, c *[]schema.Choice) { - *c = append(*c, schema.Choice{Text: s}) - }, nil) - if err != nil { - return nil, err - } - - totalTokenUsage.Prompt += tokenUsage.Prompt - totalTokenUsage.Completion += tokenUsage.Completion - - result = append(result, r...) - } - - return &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: result, - Object: "edit", - Usage: schema.OpenAIUsage{ - PromptTokens: totalTokenUsage.Prompt, - CompletionTokens: totalTokenUsage.Completion, - TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion, - }, - }, nil -} - -func ChatGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.OpenAIResponse, error) { - - // DEFS - id := uuid.New().String() - created := int(time.Now().Unix()) - - // Prepare - config, predInput, processFunctions, err := prepareChatGenerationOpenAIRequest(modelName, input, cl, ml, startupOptions) - if err != nil { - return nil, err - } - - result, tokenUsage, err := ComputeChoices(input, predInput, config, startupOptions, ml, func(s string, c *[]schema.Choice) { - if processFunctions { - // As we have to change the result before processing, we can't stream the answer (yet?) - ss := map[string]interface{}{} - // This prevent newlines to break JSON parsing for clients - s = utils.EscapeNewLines(s) - json.Unmarshal([]byte(s), &ss) - log.Debug().Msgf("Function return: %s %+v", s, ss) - - // The grammar defines the function name as "function", while OpenAI returns "name" - func_name := ss["function"] - // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object - args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) - d, _ := json.Marshal(args) - - ss["arguments"] = string(d) - ss["name"] = func_name - - // if do nothing, reply with a message - if func_name == config.FunctionsConfig.NoActionFunctionName { - log.Debug().Msgf("nothing to do, computing a reply") - - // If there is a message that the LLM already sends as part of the JSON reply, use it - arguments := map[string]interface{}{} - json.Unmarshal([]byte(d), &arguments) - m, exists := arguments["message"] - if exists { - switch message := m.(type) { - case string: - if message != "" { - log.Debug().Msgf("Reply received from LLM: %s", message) - message = Finetune(*config, predInput, message) - log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) - - *c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &message}}) - return - } - } - } - - log.Debug().Msgf("No action received from LLM, without a message, computing a reply") - // Otherwise ask the LLM to understand the JSON output and the context, and return a message - // Note: This costs (in term of CPU) another computation - config.Grammar = "" - images := []string{} - for _, m := range input.Messages { - images = append(images, m.StringImages...) - } - predFunc, err := ModelInference(input.Context, predInput, images, ml, *config, startupOptions, nil) - if err != nil { - log.Error().Msgf("inference error: %s", err.Error()) - return - } - - prediction, err := predFunc() - if err != nil { - log.Error().Msgf("inference error: %s", err.Error()) - return - } - - fineTunedResponse := Finetune(*config, predInput, prediction.Response) - *c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &fineTunedResponse}}) - } else { - // otherwise reply with the function call - *c = append(*c, schema.Choice{ - FinishReason: "function_call", - Message: &schema.Message{Role: "assistant", FunctionCall: ss}, - }) - } - - return - } - *c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}}) - }, nil) - if err != nil { - return nil, err - } - - return &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: result, - Object: "chat.completion", - Usage: schema.OpenAIUsage{ - PromptTokens: tokenUsage.Prompt, - CompletionTokens: tokenUsage.Completion, - TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, - }, - }, nil - -} - -func CompletionGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.OpenAIResponse, error) { - // Prepare - id := uuid.New().String() - created := int(time.Now().Unix()) - - binding := func(config *schema.Config) *string { - return &config.TemplateConfig.Completion - } - - config, err := prepareGenerationOpenAIRequest(binding, modelName, input, cl, ml, startupOptions) - if err != nil { - return nil, err - } - - var result []schema.Choice - - totalTokenUsage := TokenUsage{} - - for k, i := range config.PromptStrings { - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, config.TemplateConfig.Completion, model.PromptTemplateData{ - SystemPrompt: config.SystemPrompt, - Input: i, - }) - if err == nil { - i = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", i) - } - - r, tokenUsage, err := ComputeChoices( - input, i, config, startupOptions, ml, func(s string, c *[]schema.Choice) { - *c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k}) - }, nil) - if err != nil { - return nil, err - } - - totalTokenUsage.Prompt += tokenUsage.Prompt - totalTokenUsage.Completion += tokenUsage.Completion - - result = append(result, r...) - } - - return &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: result, - Object: "text_completion", - Usage: schema.OpenAIUsage{ - PromptTokens: totalTokenUsage.Prompt, - CompletionTokens: totalTokenUsage.Completion, - TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion, - }, - }, nil -} - -func StreamingChatGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (chan schema.OpenAIResponse, error) { - - // DEFS - emptyMessage := "" - id := uuid.New().String() - created := int(time.Now().Unix()) - - // Prepare - config, predInput, processFunctions, err := prepareChatGenerationOpenAIRequest(modelName, input, cl, ml, startupOptions) - if err != nil { - return nil, err - } - - if processFunctions { - // TODO: unused variable means I did something wrong. investigate once stable - log.Debug().Msgf("StreamingChatGenerationOpenAIRequest with processFunctions=true for %s?", config.Name) - } - - processor := func(s string, req *schema.OpenAIRequest, config *schema.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { - initialMessage := schema.OpenAIResponse{ - ID: id, - Created: created, - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}}, - Object: "chat.completion.chunk", - } - responses <- initialMessage - - ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage TokenUsage) bool { - resp := schema.OpenAIResponse{ - ID: id, - Created: created, - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}}, - Object: "chat.completion.chunk", - Usage: schema.OpenAIUsage{ - PromptTokens: usage.Prompt, - CompletionTokens: usage.Completion, - TotalTokens: usage.Prompt + usage.Completion, - }, - } - - responses <- resp - return true - }) - close(responses) - } - log.Trace().Msg("StreamingChatGenerationOpenAIRequest :: About to create response channel") - - responses := make(chan schema.OpenAIResponse) - - log.Trace().Msg("StreamingChatGenerationOpenAIRequest :: About to start processor goroutine") - - go processor(predInput, input, config, ml, responses) - - log.Trace().Msg("StreamingChatGenerationOpenAIRequest :: DONE! successfully returning to caller!") - - return responses, nil - -} - -func StreamingCompletionGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (chan schema.OpenAIResponse, error) { - // DEFS - id := uuid.New().String() - created := int(time.Now().Unix()) - - binding := func(config *schema.Config) *string { - return &config.TemplateConfig.Completion - } - - // Prepare - - config, err := prepareGenerationOpenAIRequest(binding, modelName, input, cl, ml, startupOptions) - if err != nil { - return nil, err - } - - processor := func(s string, req *schema.OpenAIRequest, config *schema.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { - ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage TokenUsage) bool { - resp := schema.OpenAIResponse{ - ID: id, - Created: created, - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{ - { - Index: 0, - Text: s, - }, - }, - Object: "text_completion", - Usage: schema.OpenAIUsage{ - PromptTokens: usage.Prompt, - CompletionTokens: usage.Completion, - TotalTokens: usage.Prompt + usage.Completion, - }, - } - log.Debug().Msgf("Sending goroutine: %s", s) - - responses <- resp - return true - }) - close(responses) - } - - if len(config.PromptStrings) > 1 { - return nil, errors.New("cannot handle more than 1 `PromptStrings` when Streaming") - - } - - predInput := config.PromptStrings[0] - - //A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, config.TemplateConfig.Completion, model.PromptTemplateData{ - Input: predInput, - }) - if err == nil { - predInput = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", predInput) - } - - log.Trace().Msg("StreamingCompletionGenerationOpenAIRequest :: About to create response channel") - - responses := make(chan schema.OpenAIResponse) - - log.Trace().Msg("StreamingCompletionGenerationOpenAIRequest :: About to start processor goroutine") - - go processor(predInput, input, config, ml, responses) - - log.Trace().Msg("StreamingCompletionGenerationOpenAIRequest :: DONE! successfully returning to caller!") - - return responses, nil -} diff --git a/core/backend/transcription.go b/core/backend/transcription.go deleted file mode 100644 index 449b73230fb4..000000000000 --- a/core/backend/transcription.go +++ /dev/null @@ -1,52 +0,0 @@ -package backend - -import ( - "context" - "fmt" - - "github.com/go-skynet/LocalAI/core/services" - "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/schema" -) - -func ModelTranscription(audio, language string, loader *model.ModelLoader, c schema.Config, o *schema.StartupOptions) (*schema.WhisperResult, error) { - - opts := modelOpts(c, o, []model.Option{ - model.WithBackendString(model.WhisperBackend), - model.WithModel(c.Model), - model.WithContext(o.Context), - model.WithThreads(uint32(c.Threads)), - model.WithAssetDir(o.AssetsDestination), - model.WithExternalBackends(o.ExternalGRPCBackends, false), - }) - - whisperModel, err := loader.BackendLoader(opts...) - if err != nil { - return nil, err - } - - if whisperModel == nil { - return nil, fmt.Errorf("could not load whisper model") - } - - return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ - Dst: audio, - Language: language, - Threads: uint32(c.Threads), - }) -} - -func TranscriptionOpenAIRequest(modelName string, input *schema.OpenAIRequest, audioFilePath string, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.WhisperResult, error) { - config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions) - if err != nil { - return nil, fmt.Errorf("failed reading parameters from request:%w", err) - } - - tr, err := ModelTranscription(audioFilePath, input.Language, ml, *config, startupOptions) - if err != nil { - return nil, err - } - - return tr, nil -} diff --git a/core/http/api.go b/core/http/api.go deleted file mode 100644 index 8cf2c3be2f31..000000000000 --- a/core/http/api.go +++ /dev/null @@ -1,169 +0,0 @@ -package http - -import ( - "errors" - "strings" - - "github.com/go-skynet/LocalAI/core/http/endpoints/localai" - "github.com/go-skynet/LocalAI/core/http/endpoints/openai" - "github.com/go-skynet/LocalAI/core/services" - "github.com/go-skynet/LocalAI/internal" - "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/schema" - - "github.com/gofiber/fiber/v2" - "github.com/gofiber/fiber/v2/middleware/cors" - "github.com/gofiber/fiber/v2/middleware/logger" - "github.com/gofiber/fiber/v2/middleware/recover" -) - -func App(cl *services.ConfigLoader, ml *model.ModelLoader, options *schema.StartupOptions) (*fiber.App, error) { - - // Return errors as JSON responses - app := fiber.New(fiber.Config{ - BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB - DisableStartupMessage: options.DisableMessage, - // Override default error handler - ErrorHandler: func(ctx *fiber.Ctx, err error) error { - // Status code defaults to 500 - code := fiber.StatusInternalServerError - - // Retrieve the custom status code if it's a *fiber.Error - var e *fiber.Error - if errors.As(err, &e) { - code = e.Code - } - - // Send custom error page - return ctx.Status(code).JSON( - schema.ErrorResponse{ - Error: &schema.APIError{Message: err.Error(), Code: code}, - }, - ) - }, - }) - - if options.Debug { - app.Use(logger.New(logger.Config{ - Format: "[${ip}]:${port} ${status} - ${method} ${path}\n", - })) - } - - // Default middleware config - app.Use(recover.New()) - - if options.Metrics != nil { - app.Use(localai.MetricsAPIMiddleware(options.Metrics)) - } - - // Auth middleware checking if API key is valid. If no API key is set, no auth is required. - auth := func(c *fiber.Ctx) error { - if len(options.ApiKeys) == 0 { - return c.Next() - } - - authHeader := c.Get("Authorization") - if authHeader == "" { - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"}) - } - authHeaderParts := strings.Split(authHeader, " ") - if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" { - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"}) - } - - apiKey := authHeaderParts[1] - for _, key := range options.ApiKeys { - if apiKey == key { - return c.Next() - } - } - - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"}) - - } - - if options.CORS { - var c func(ctx *fiber.Ctx) error - if options.CORSAllowOrigins == "" { - c = cors.New() - } else { - c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins}) - } - - app.Use(c) - } - - // LocalAI API endpoints - galleryService := services.NewGalleryApplier(options.ModelPath) - galleryService.Start(options.Context, cl) - - app.Get("/version", auth, func(c *fiber.Ctx) error { - return c.JSON(struct { - Version string `json:"version"` - }{Version: internal.PrintableVersion()}) - }) - - modelGalleryService := localai.CreateModelGalleryEndpointService(options.Galleries, options.ModelPath, galleryService) - app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint()) - app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint()) - app.Get("/models/galleries", auth, modelGalleryService.ListModelGalleriesEndpoint()) - app.Post("/models/galleries", auth, modelGalleryService.AddModelGalleryEndpoint()) - app.Delete("/models/galleries", auth, modelGalleryService.RemoveModelGalleryEndpoint()) - app.Get("/models/jobs/:uuid", auth, modelGalleryService.GetOpStatusEndpoint()) - app.Get("/models/jobs", auth, modelGalleryService.GetAllStatusEndpoint()) - - // openAI compatible API endpoint - - // chat - app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, ml, options)) - app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, ml, options)) - - // edit - app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, options)) - app.Post("/edits", auth, openai.EditEndpoint(cl, ml, options)) - - // completion - app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, ml, options)) - app.Post("/completions", auth, openai.CompletionEndpoint(cl, ml, options)) - app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, ml, options)) - - // embeddings - app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, options)) - app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, options)) - app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, options)) - - // audio - app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, options)) - app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, options)) - - // images - app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, options)) - - if options.ImageDir != "" { - app.Static("/generated-images", options.ImageDir) - } - - if options.AudioDir != "" { - app.Static("/generated-audio", options.AudioDir) - } - - ok := func(c *fiber.Ctx) error { - return c.SendStatus(200) - } - - // Kubernetes health checks - app.Get("/healthz", ok) - app.Get("/readyz", ok) - - app.Get("/metrics", localai.MetricsHandler()) - - backendMonitor := services.NewBackendMonitor(cl, ml, options) - app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor)) - app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor)) - - // model listing - app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml)) - app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml)) - - return app, nil -} diff --git a/core/http/endpoints/localai/backend_monitor.go b/core/http/endpoints/localai/backend_monitor.go deleted file mode 100644 index db99c54ef7ac..000000000000 --- a/core/http/endpoints/localai/backend_monitor.go +++ /dev/null @@ -1,34 +0,0 @@ -package localai - -import ( - "github.com/go-skynet/LocalAI/core/services" - "github.com/go-skynet/LocalAI/pkg/schema" - "github.com/gofiber/fiber/v2" -) - -func BackendMonitorEndpoint(bm *services.BackendMonitor) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - input := new(schema.BackendMonitorRequest) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return err - } - - resp, err := bm.CheckAndSample(input.Model) - if err != nil { - return err - } - return c.JSON(resp) - } -} - -func BackendShutdownEndpoint(bm *services.BackendMonitor) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - input := new(schema.BackendMonitorRequest) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return err - } - return bm.ShutdownModel(input.Model) - } -} diff --git a/core/http/endpoints/localai/gallery.go b/core/http/endpoints/localai/gallery.go deleted file mode 100644 index 6b4d73b596e0..000000000000 --- a/core/http/endpoints/localai/gallery.go +++ /dev/null @@ -1,148 +0,0 @@ -package localai - -import ( - "encoding/json" - "fmt" - "slices" - - "github.com/go-skynet/LocalAI/core/services" - "github.com/go-skynet/LocalAI/pkg/gallery" - "github.com/gofiber/fiber/v2" - "github.com/google/uuid" - "github.com/rs/zerolog/log" -) - -/// Endpoint Service - -type ModelGalleryEndpointService struct { - galleries []gallery.Gallery - modelPath string - galleryApplier *services.GalleryApplier -} - -type GalleryModel struct { - ID string `json:"id"` - gallery.GalleryModel -} - -func CreateModelGalleryEndpointService(galleries []gallery.Gallery, modelPath string, galleryApplier *services.GalleryApplier) ModelGalleryEndpointService { - return ModelGalleryEndpointService{ - galleries: galleries, - modelPath: modelPath, - galleryApplier: galleryApplier, - } -} - -func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - status := mgs.galleryApplier.GetStatus(c.Params("uuid")) - if status == nil { - return fmt.Errorf("could not find any status for ID") - } - return c.JSON(status) - } -} - -func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - return c.JSON(mgs.galleryApplier.GetAllStatus()) - } -} - -func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - input := new(GalleryModel) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return err - } - - uuid, err := uuid.NewUUID() - if err != nil { - return err - } - mgs.galleryApplier.C <- gallery.GalleryOp{ - Req: input.GalleryModel, - Id: uuid.String(), - GalleryName: input.ID, - Galleries: mgs.galleries, - } - return c.JSON(struct { - ID string `json:"uuid"` - StatusURL string `json:"status"` - }{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()}) - } -} - -func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries) - - models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath) - if err != nil { - return err - } - log.Debug().Msgf("Models found from galleries: %+v", models) - for _, m := range models { - log.Debug().Msgf("Model found from galleries: %+v", m) - } - dat, err := json.Marshal(models) - if err != nil { - return err - } - return c.Send(dat) - } -} - -// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents! -func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - log.Debug().Msgf("Listing model galleries %+v", mgs.galleries) - dat, err := json.Marshal(mgs.galleries) - if err != nil { - return err - } - return c.Send(dat) - } -} - -func (mgs *ModelGalleryEndpointService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - input := new(gallery.Gallery) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return err - } - if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool { - return gallery.Name == input.Name - }) { - return fmt.Errorf("%s already exists", input.Name) - } - dat, err := json.Marshal(mgs.galleries) - if err != nil { - return err - } - log.Debug().Msgf("Adding %+v to gallery list", *input) - mgs.galleries = append(mgs.galleries, *input) - return c.Send(dat) - } -} - -func (mgs *ModelGalleryEndpointService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - input := new(gallery.Gallery) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return err - } - if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool { - return gallery.Name == input.Name - }) { - return fmt.Errorf("%s is not currently registered", input.Name) - } - mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool { - return gallery.Name == input.Name - }) - return c.Send(nil) - } -} diff --git a/core/http/endpoints/localai/metrics.go b/core/http/endpoints/localai/metrics.go deleted file mode 100644 index 44db1897f4df..000000000000 --- a/core/http/endpoints/localai/metrics.go +++ /dev/null @@ -1,42 +0,0 @@ -package localai - -import ( - "time" - - "github.com/go-skynet/LocalAI/pkg/schema" - "github.com/gofiber/fiber/v2" - "github.com/gofiber/fiber/v2/middleware/adaptor" - "github.com/prometheus/client_golang/prometheus/promhttp" -) - -func MetricsHandler() fiber.Handler { - return adaptor.HTTPHandler(promhttp.Handler()) -} - -type apiMiddlewareConfig struct { - Filter func(c *fiber.Ctx) bool - metrics *schema.LocalAIMetrics -} - -func MetricsAPIMiddleware(metrics *schema.LocalAIMetrics) fiber.Handler { - cfg := apiMiddlewareConfig{ - metrics: metrics, - Filter: func(c *fiber.Ctx) bool { - return c.Path() == "/metrics" - }, - } - - return func(c *fiber.Ctx) error { - if cfg.Filter != nil && cfg.Filter(c) { - return c.Next() - } - path := c.Path() - method := c.Method() - - start := time.Now() - err := c.Next() - elapsed := float64(time.Since(start)) / float64(time.Second) - cfg.metrics.ObserveAPICall(method, path, elapsed) - return err - } -} diff --git a/core/http/endpoints/localai/tts.go b/core/http/endpoints/localai/tts.go deleted file mode 100644 index bf4904534dd6..000000000000 --- a/core/http/endpoints/localai/tts.go +++ /dev/null @@ -1,25 +0,0 @@ -package localai - -import ( - "github.com/go-skynet/LocalAI/core/backend" - "github.com/go-skynet/LocalAI/core/services" - "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/schema" - "github.com/gofiber/fiber/v2" -) - -func TTSEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - input := new(schema.TTSRequest) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return err - } - - filePath, _, err := backend.ModelTTS(input.Backend, input.Input, input.Model, ml, so) - if err != nil { - return err - } - return c.Download(filePath) - } -} diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go deleted file mode 100644 index 8ea38f62a697..000000000000 --- a/core/http/endpoints/openai/chat.go +++ /dev/null @@ -1,97 +0,0 @@ -package openai - -import ( - "bufio" - "bytes" - "encoding/json" - "fmt" - - "github.com/go-skynet/LocalAI/core/backend" - "github.com/go-skynet/LocalAI/core/services" - "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/schema" - "github.com/gofiber/fiber/v2" - "github.com/rs/zerolog/log" - "github.com/valyala/fasthttp" -) - -func ChatEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) func(c *fiber.Ctx) error { - - emptyMessage := "" - - return func(c *fiber.Ctx) error { - modelName, input, err := readInput(c, startupOptions, ml, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - // The scary comment I feel like I forgot about along the way: - // - // functions are not supported in stream mode (yet?) - // - if input.Stream { - log.Debug().Msgf("Stream request received") - c.Context().SetContentType("text/event-stream") - //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) - // c.Set("Content-Type", "text/event-stream") - c.Set("Cache-Control", "no-cache") - c.Set("Connection", "keep-alive") - c.Set("Transfer-Encoding", "chunked") - - responses, err := backend.StreamingChatGenerationOpenAIRequest(modelName, input, cl, ml, startupOptions) - if err != nil { - return fmt.Errorf("failed establishing streaming chat request :%w", err) - } - c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { - usage := &schema.OpenAIUsage{} - id := "" - created := 0 - for ev := range responses { - usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it - id = ev.ID - created = ev.Created // Similarly, grab the ID and created from any / the last response so we can use it for the stop - var buf bytes.Buffer - enc := json.NewEncoder(&buf) - enc.Encode(ev) - log.Debug().Msgf("Sending chunk: %s", buf.String()) - _, err := fmt.Fprintf(w, "data: %v\n", buf.String()) - if err != nil { - log.Debug().Msgf("Sending chunk failed: %v", err) - input.Cancel() - break - } - w.Flush() - } - - resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{ - { - FinishReason: "stop", - Index: 0, - Delta: &schema.Message{Content: &emptyMessage}, - }}, - Object: "chat.completion.chunk", - Usage: *usage, - } - respData, _ := json.Marshal(resp) - - w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) - w.WriteString("data: [DONE]\n\n") - w.Flush() - })) - return nil - } - ////////////////////////////////////////// - - resp, err := backend.ChatGenerationOpenAIRequest(modelName, input, cl, ml, startupOptions) - if err != nil { - return fmt.Errorf("error generating chat request: +%w", err) - } - respData, _ := json.Marshal(resp) // TODO this is only used for the debug log and costs performance. monitor this? - log.Debug().Msgf("Response: %s", respData) - return c.JSON(resp) - } -} diff --git a/core/http/endpoints/openai/completion.go b/core/http/endpoints/openai/completion.go deleted file mode 100644 index bda90b2162ec..000000000000 --- a/core/http/endpoints/openai/completion.go +++ /dev/null @@ -1,91 +0,0 @@ -package openai - -import ( - "bufio" - "bytes" - "encoding/json" - "fmt" - "time" - - "github.com/go-skynet/LocalAI/core/backend" - "github.com/go-skynet/LocalAI/core/services" - "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/schema" - "github.com/gofiber/fiber/v2" - "github.com/google/uuid" - "github.com/rs/zerolog/log" - "github.com/valyala/fasthttp" -) - -// https://platform.openai.com/docs/api-reference/completions -func CompletionEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error { - id := uuid.New().String() - created := int(time.Now().Unix()) - - return func(c *fiber.Ctx) error { - modelName, input, err := readInput(c, so, ml, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - log.Debug().Msgf("`input`: %+v", input) - - if input.Stream { - log.Debug().Msgf("Stream request received") - c.Context().SetContentType("text/event-stream") - //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) - //c.Set("Content-Type", "text/event-stream") - c.Set("Cache-Control", "no-cache") - c.Set("Connection", "keep-alive") - c.Set("Transfer-Encoding", "chunked") - - responses, err := backend.StreamingCompletionGenerationOpenAIRequest(modelName, input, cl, ml, so) - if err != nil { - return fmt.Errorf("failed establishing streaming completion request :%w", err) - } - c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { - - for ev := range responses { - var buf bytes.Buffer - enc := json.NewEncoder(&buf) - enc.Encode(ev) - - log.Debug().Msgf("Sending chunk: %s", buf.String()) - fmt.Fprintf(w, "data: %v\n", buf.String()) - w.Flush() - } - - resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{ - { - Index: 0, - FinishReason: "stop", - }, - }, - Object: "text_completion", - } - respData, _ := json.Marshal(resp) - - w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) - w.WriteString("data: [DONE]\n\n") - w.Flush() - })) - return nil - } - - /////////// - - resp, err := backend.CompletionGenerationOpenAIRequest(modelName, input, cl, ml, so) - if err != nil { - return fmt.Errorf("error generating completion request: +%w", err) - } - jsonResult, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", jsonResult) - - // Return the prediction in the response body - return c.JSON(resp) - } -} diff --git a/core/http/endpoints/openai/edit.go b/core/http/endpoints/openai/edit.go deleted file mode 100644 index aa85239c7c58..000000000000 --- a/core/http/endpoints/openai/edit.go +++ /dev/null @@ -1,34 +0,0 @@ -package openai - -import ( - "encoding/json" - "fmt" - - "github.com/go-skynet/LocalAI/core/backend" - "github.com/go-skynet/LocalAI/core/services" - "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/schema" - "github.com/gofiber/fiber/v2" - - "github.com/rs/zerolog/log" -) - -func EditEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - modelFile, input, err := readInput(c, so, ml, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - resp, err := backend.EditGenerationOpenAIRequest(modelFile, input, cl, ml, so) - if err != nil { - return err - } - - jsonResult, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", jsonResult) - - // Return the prediction in the response body - return c.JSON(resp) - } -} diff --git a/core/http/endpoints/openai/embeddings.go b/core/http/endpoints/openai/embeddings.go deleted file mode 100644 index d20f1605f0c4..000000000000 --- a/core/http/endpoints/openai/embeddings.go +++ /dev/null @@ -1,35 +0,0 @@ -package openai - -import ( - "encoding/json" - "fmt" - - "github.com/go-skynet/LocalAI/core/backend" - "github.com/go-skynet/LocalAI/core/services" - "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/schema" - - "github.com/gofiber/fiber/v2" - "github.com/rs/zerolog/log" -) - -// https://platform.openai.com/docs/api-reference/embeddings -func EmbeddingsEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - modelFile, input, err := readInput(c, so, ml, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - resp, err := backend.EmbeddingOpenAIRequest(modelFile, input, cl, ml, so) - if err != nil { - return err - } - - jsonResult, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", jsonResult) - - // Return the prediction in the response body - return c.JSON(resp) - } -} diff --git a/core/http/endpoints/openai/image.go b/core/http/endpoints/openai/image.go deleted file mode 100644 index e571c8e74a64..000000000000 --- a/core/http/endpoints/openai/image.go +++ /dev/null @@ -1,48 +0,0 @@ -package openai - -import ( - "encoding/json" - "fmt" - - "github.com/go-skynet/LocalAI/core/backend" - "github.com/go-skynet/LocalAI/core/services" - "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/schema" - "github.com/gofiber/fiber/v2" - "github.com/rs/zerolog/log" -) - -// https://platform.openai.com/docs/api-reference/images/create - -/* -* - - curl http://localhost:8080/v1/images/generations \ - -H "Content-Type: application/json" \ - -d '{ - "prompt": "A cute baby sea otter", - "n": 1, - "size": "512x512" - }' - -* -*/ -func ImageEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - modelName, input, err := readInput(c, so, ml, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - resp, err := backend.ImageGenerationOpenAIRequest(modelName, input, cl, ml, so) - if err != nil { - return fmt.Errorf("error generating image request: +%w", err) - } - - jsonResult, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", jsonResult) - - // Return the prediction in the response body - return c.JSON(resp) - } -} diff --git a/core/http/endpoints/openai/request.go b/core/http/endpoints/openai/request.go deleted file mode 100644 index aa9b882f851e..000000000000 --- a/core/http/endpoints/openai/request.go +++ /dev/null @@ -1,57 +0,0 @@ -package openai - -import ( - "context" - "encoding/json" - "fmt" - "strings" - - "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/schema" - "github.com/gofiber/fiber/v2" - "github.com/rs/zerolog/log" -) - -func readInput(c *fiber.Ctx, o *schema.StartupOptions, ml *model.ModelLoader, randomModel bool) (string, *schema.OpenAIRequest, error) { - input := new(schema.OpenAIRequest) - ctx, cancel := context.WithCancel(o.Context) - input.Context = ctx - input.Cancel = cancel - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return "", nil, fmt.Errorf("failed parsing request body: %w", err) - } - - modelFile := input.Model - - if c.Params("model") != "" { - modelFile = c.Params("model") - } - - received, _ := json.Marshal(input) - - log.Debug().Msgf("Request received: %s", string(received)) - - // Set model from bearer token, if available - bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ") - bearerExists := bearer != "" && ml.ExistsInModelPath(bearer) - - // If no model was specified, take the first available - if modelFile == "" && !bearerExists && randomModel { - models, _ := ml.ListModels() - if len(models) > 0 { - modelFile = models[0] - log.Debug().Msgf("No model specified, using: %s", modelFile) - } else { - log.Debug().Msgf("No model specified, returning error") - return "", nil, fmt.Errorf("no model specified") - } - } - - // If a model is found in bearer token takes precedence - if bearerExists { - log.Debug().Msgf("Using model from bearer token: %s", bearer) - modelFile = bearer - } - return modelFile, input, nil -} diff --git a/core/http/endpoints/openai/transcription.go b/core/http/endpoints/openai/transcription.go deleted file mode 100644 index 0d9de3ba6032..000000000000 --- a/core/http/endpoints/openai/transcription.go +++ /dev/null @@ -1,49 +0,0 @@ -package openai - -import ( - "fmt" - "net/http" - "os" - "path" - - "github.com/go-skynet/LocalAI/core/backend" - "github.com/go-skynet/LocalAI/core/services" - "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/schema" - "github.com/go-skynet/LocalAI/pkg/utils" - - "github.com/gofiber/fiber/v2" - "github.com/rs/zerolog/log" -) - -// https://platform.openai.com/docs/api-reference/audio/create -func TranscriptEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - modelName, input, err := readInput(c, so, ml, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - // retrieve the file data from the request - file, err := c.FormFile("file") - if err != nil { - return err - } - - dst, err := utils.CreateTempFileFromMultipartFile(file, "", "transcription") // 3rd param formerly whisper - if err != nil { - return err - } - - log.Debug().Msgf("Audio file copied to: %+v", dst) - defer os.RemoveAll(path.Dir(dst)) - - tr, err := backend.TranscriptionOpenAIRequest(modelName, input, dst, cl, ml, so) - if err != nil { - return fmt.Errorf("error generating transcription request: +%w", err) - } - log.Debug().Msgf("Trascribed: %+v", tr) - // TODO: handle different outputs here - return c.Status(http.StatusOK).JSON(tr) - } -} diff --git a/core/mqtt/manager.go b/core/mqtt/manager.go deleted file mode 100644 index d8e096abbd63..000000000000 --- a/core/mqtt/manager.go +++ /dev/null @@ -1,24 +0,0 @@ -package mqtt - -import ( - "github.com/go-skynet/LocalAI/core/services" - "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/schema" -) - -// PLACEHOLDER DURING PART 1 OF THE REFACTOR - -type MQTTManager struct { - configLoader *services.ConfigLoader - modelLoader *model.ModelLoader - startupOptions *schema.StartupOptions -} - -func NewMQTTManager(cl *services.ConfigLoader, ml *model.ModelLoader, options *schema.StartupOptions) (*MQTTManager, error) { - - return &MQTTManager{ - configLoader: cl, - modelLoader: ml, - startupOptions: options, - }, nil -} diff --git a/core/services/backend_monitor.go b/core/services/backend_monitor.go deleted file mode 100644 index e1c88283df3e..000000000000 --- a/core/services/backend_monitor.go +++ /dev/null @@ -1,138 +0,0 @@ -package services - -import ( - "context" - "fmt" - "strings" - - "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/schema" - "github.com/rs/zerolog/log" - - gopsutil "github.com/shirou/gopsutil/v3/process" -) - -type BackendMonitor struct { - configLoader *ConfigLoader - modelLoader *model.ModelLoader - options *schema.StartupOptions // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name. -} - -func NewBackendMonitor(configLoader *ConfigLoader, modelLoader *model.ModelLoader, options *schema.StartupOptions) *BackendMonitor { - return &BackendMonitor{ - configLoader: configLoader, - modelLoader: modelLoader, - options: options, - } -} - -func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) { - config, exists := bm.configLoader.GetConfig(model) - var backend string - if exists { - backend = config.Model - } else { - // Last ditch effort: use it raw, see if a backend happens to match. - backend = model - } - - if !strings.HasSuffix(backend, ".bin") { - backend = fmt.Sprintf("%s.bin", backend) - } - - pid, err := bm.modelLoader.GetGRPCPID(backend) - - if err != nil { - log.Error().Msgf("model %s : failed to find pid %+v", model, err) - return nil, err - } - - // Name is slightly frightening but this does _not_ create a new process, rather it looks up an existing process by PID. - backendProcess, err := gopsutil.NewProcess(int32(pid)) - - if err != nil { - log.Error().Msgf("model %s [PID %d] : error getting process info %+v", model, pid, err) - return nil, err - } - - memInfo, err := backendProcess.MemoryInfo() - - if err != nil { - log.Error().Msgf("model %s [PID %d] : error getting memory info %+v", model, pid, err) - return nil, err - } - - memPercent, err := backendProcess.MemoryPercent() - if err != nil { - log.Error().Msgf("model %s [PID %d] : error getting memory percent %+v", model, pid, err) - return nil, err - } - - cpuPercent, err := backendProcess.CPUPercent() - if err != nil { - log.Error().Msgf("model %s [PID %d] : error getting cpu percent %+v", model, pid, err) - return nil, err - } - - return &schema.BackendMonitorResponse{ - MemoryInfo: memInfo, - MemoryPercent: memPercent, - CPUPercent: cpuPercent, - }, nil -} - -func (bm BackendMonitor) getModelLoaderIDFromModelName(modelName string) (string, error) { - config, exists := bm.configLoader.GetConfig(modelName) - var backendId string - if exists { - backendId = config.Model - } else { - // Last ditch effort: use it raw, see if a backend happens to match. - backendId = modelName - } - - if !strings.HasSuffix(backendId, ".bin") { - backendId = fmt.Sprintf("%s.bin", backendId) - } - - return backendId, nil -} - -func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse, error) { - backendId, err := bm.getModelLoaderIDFromModelName(modelName) - if err != nil { - return nil, err - } - modelAddr := bm.modelLoader.CheckIsLoaded(backendId) - if modelAddr == "" { - return nil, fmt.Errorf("backend %s is not currently loaded", backendId) - } - - status, rpcErr := modelAddr.GRPC(false, nil).Status(context.TODO()) - if rpcErr != nil { - log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error()) - val, slbErr := bm.SampleLocalBackendProcess(backendId) - if slbErr != nil { - return nil, fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error()) - } - return &proto.StatusResponse{ - State: proto.StatusResponse_ERROR, - Memory: &proto.MemoryUsageData{ - Total: val.MemoryInfo.VMS, - Breakdown: map[string]uint64{ - "gopsutil-RSS": val.MemoryInfo.RSS, - }, - }, - }, nil - } - return status, nil -} - -func (bm BackendMonitor) ShutdownModel(modelName string) error { - backendId, err := bm.getModelLoaderIDFromModelName(modelName) - if err != nil { - return err - } - return bm.modelLoader.ShutdownModel(backendId) -} diff --git a/core/services/config.go b/core/services/config.go deleted file mode 100644 index 66c9fd7fe7d5..000000000000 --- a/core/services/config.go +++ /dev/null @@ -1,157 +0,0 @@ -package services - -import ( - "errors" - "fmt" - "io/fs" - "os" - "path/filepath" - "strings" - "sync" - - "github.com/go-skynet/LocalAI/pkg/schema" - "github.com/go-skynet/LocalAI/pkg/utils" - "github.com/rs/zerolog/log" -) - -type ConfigLoader struct { - configs map[string]schema.Config - sync.Mutex -} - -func NewConfigLoader() *ConfigLoader { - return &ConfigLoader{ - configs: make(map[string]schema.Config), - } -} - -// TODO: check this is correct post-merge -func (cm *ConfigLoader) LoadConfig(file string) error { - cm.Lock() - defer cm.Unlock() - c, err := schema.ReadSingleConfigFile(file) - if err != nil { - return fmt.Errorf("cannot read config file: %w", err) - } - - cm.configs[c.Name] = *c - return nil -} - -func (cm *ConfigLoader) GetConfig(m string) (schema.Config, bool) { - cm.Lock() - defer cm.Unlock() - v, exists := cm.configs[m] - return v, exists -} - -func (cm *ConfigLoader) GetAllConfigs() []schema.Config { - cm.Lock() - defer cm.Unlock() - var res []schema.Config - for _, v := range cm.configs { - res = append(res, v) - } - return res -} - -func (cm *ConfigLoader) ListConfigs() []string { - cm.Lock() - defer cm.Unlock() - var res []string - for k := range cm.configs { - res = append(res, k) - } - return res -} - -func (cm *ConfigLoader) LoadConfigs(path string) error { - cm.Lock() - defer cm.Unlock() - entries, err := os.ReadDir(path) - if err != nil { - return err - } - files := make([]fs.FileInfo, 0, len(entries)) - for _, entry := range entries { - info, err := entry.Info() - if err != nil { - return err - } - files = append(files, info) - } - for _, file := range files { - // Skip templates, YAML and .keep files - if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") { - continue - } - c, err := schema.ReadSingleConfigFile(filepath.Join(path, file.Name())) - if err == nil { - cm.configs[c.Name] = *c - } - } - - return nil -} - -// Preload prepare models if they are not local but url or huggingface repositories -func (cm *ConfigLoader) Preload(modelPath string) error { - cm.Lock() - defer cm.Unlock() - - status := func(fileName, current, total string, percent float64) { - utils.DisplayDownloadFunction(fileName, current, total, percent) - } - - log.Info().Msgf("Preloading models from %s", modelPath) - - for _, config := range cm.configs { - - // Download files and verify their SHA - for _, file := range config.DownloadFiles { - log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename) - - if err := utils.VerifyPath(file.Filename, modelPath); err != nil { - return err - } - // Create file path - filePath := filepath.Join(modelPath, file.Filename) - - if err := utils.DownloadFile(file.URI, filePath, file.SHA256, status); err != nil { - return err - } - } - - modelURL := config.PredictionOptions.Model - modelURL = utils.ConvertURL(modelURL) - - if utils.LooksLikeURL(modelURL) { - // md5 of model name - md5Name := utils.MD5(modelURL) - - // check if file exists - if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) { - err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", status) - if err != nil { - return err - } - } - } - } - - return nil -} - -func (cl *ConfigLoader) LoadConfigFile(file string) error { - cl.Lock() - defer cl.Unlock() - c, err := schema.ReadConfigFile(file) - if err != nil { - return fmt.Errorf("cannot load config file: %w", err) - } - - for _, cc := range c { - cl.configs[cc.Name] = *cc - } - return nil -} diff --git a/core/services/gallery.go b/core/services/gallery.go deleted file mode 100644 index edc4e6cc71d9..000000000000 --- a/core/services/gallery.go +++ /dev/null @@ -1,160 +0,0 @@ -package services - -import ( - "context" - "encoding/json" - "os" - "strings" - "sync" - - "github.com/go-skynet/LocalAI/pkg/gallery" - "github.com/go-skynet/LocalAI/pkg/utils" - "gopkg.in/yaml.v2" -) - -type GalleryApplier struct { - modelPath string - sync.Mutex - C chan gallery.GalleryOp - statuses map[string]*gallery.GalleryOpStatus -} - -func NewGalleryApplier(modelPath string) *GalleryApplier { - return &GalleryApplier{ - modelPath: modelPath, - C: make(chan gallery.GalleryOp), - statuses: make(map[string]*gallery.GalleryOpStatus), - } -} - -func (g *GalleryApplier) UpdateStatus(s string, op *gallery.GalleryOpStatus) { - g.Lock() - defer g.Unlock() - g.statuses[s] = op -} - -func (g *GalleryApplier) GetStatus(s string) *gallery.GalleryOpStatus { - g.Lock() - defer g.Unlock() - - return g.statuses[s] -} - -func (g *GalleryApplier) GetAllStatus() map[string]*gallery.GalleryOpStatus { - g.Lock() - defer g.Unlock() - - return g.statuses -} - -func (g *GalleryApplier) Start(c context.Context, cm *ConfigLoader) { - go func() { - for { - select { - case <-c.Done(): - return - case op := <-g.C: - utils.ResetDownloadTimers() - - g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Message: "processing", Progress: 0}) - - // updates the status with an error - updateError := func(e error) { - g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()}) - } - - // displayDownload displays the download progress - progressCallback := func(fileName string, current string, total string, percentage float64) { - g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) - utils.DisplayDownloadFunction(fileName, current, total, percentage) - } - - var err error - // if the request contains a gallery name, we apply the gallery from the gallery list - if op.GalleryName != "" { - if strings.Contains(op.GalleryName, "@") { - err = gallery.InstallModelFromGallery(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback) - } else { - err = gallery.InstallModelFromGalleryByName(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback) - } - } else { - err = PrepareModel(g.modelPath, op.Req, cm, progressCallback) - } - - if err != nil { - updateError(err) - continue - } - - // Reload models - err = cm.LoadConfigs(g.modelPath) - if err != nil { - updateError(err) - continue - } - - g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Processed: true, Message: "completed", Progress: 100}) - } - } - }() -} - -type galleryModel struct { - gallery.GalleryModel `yaml:",inline"` // https://github.com/go-yaml/yaml/issues/63 - ID string `json:"id"` -} - -func PrepareModel(modelPath string, req gallery.GalleryModel, cm *ConfigLoader, downloadStatus func(string, string, string, float64)) error { - - config, err := gallery.GetInstallableModelFromURL(req.URL) - if err != nil { - return err - } - - config.Files = append(config.Files, req.AdditionalFiles...) - - return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus) -} - -func processRequests(modelPath, s string, cm *ConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error { - var err error - for _, r := range requests { - utils.ResetDownloadTimers() - if r.ID == "" { - err = PrepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction) - } else { - if strings.Contains(r.ID, "@") { - err = gallery.InstallModelFromGallery( - galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction) - } else { - err = gallery.InstallModelFromGalleryByName( - galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction) - } - } - } - return err -} - -func ApplyGalleryFromFile(modelPath, s string, cm *ConfigLoader, galleries []gallery.Gallery) error { - dat, err := os.ReadFile(s) - if err != nil { - return err - } - var requests []galleryModel - - if err := yaml.Unmarshal(dat, &requests); err != nil { - return err - } - - return processRequests(modelPath, s, cm, galleries, requests) -} - -func ApplyGalleryFromString(modelPath, s string, cm *ConfigLoader, galleries []gallery.Gallery) error { - var requests []galleryModel - err := json.Unmarshal([]byte(s), &requests) - if err != nil { - return err - } - - return processRequests(modelPath, s, cm, galleries, requests) -} diff --git a/core/services/metrics.go b/core/services/metrics.go deleted file mode 100644 index 92ab571a1db0..000000000000 --- a/core/services/metrics.go +++ /dev/null @@ -1,29 +0,0 @@ -package services - -import ( - "github.com/go-skynet/LocalAI/pkg/schema" - "go.opentelemetry.io/otel/exporters/prometheus" - api "go.opentelemetry.io/otel/metric" - "go.opentelemetry.io/otel/sdk/metric" -) - -// setupOTelSDK bootstraps the OpenTelemetry pipeline. -// If it does not return an error, make sure to call shutdown for proper cleanup. -func SetupMetrics() (*schema.LocalAIMetrics, error) { - exporter, err := prometheus.New() - if err != nil { - return nil, err - } - provider := metric.NewMeterProvider(metric.WithReader(exporter)) - meter := provider.Meter("github.com/go-skynet/LocalAI") - - apiTimeMetric, err := meter.Float64Histogram("api_call", api.WithDescription("api calls")) - if err != nil { - return nil, err - } - - return &schema.LocalAIMetrics{ - Meter: meter, - ApiTimeMetric: apiTimeMetric, - }, nil -} diff --git a/core/startup/config_file_watcher.go b/core/startup/config_file_watcher.go deleted file mode 100644 index 218801ee436f..000000000000 --- a/core/startup/config_file_watcher.go +++ /dev/null @@ -1,100 +0,0 @@ -package startup - -import ( - "encoding/json" - "fmt" - "os" - "path" - - "github.com/fsnotify/fsnotify" - "github.com/go-skynet/LocalAI/pkg/schema" - "github.com/imdario/mergo" - "github.com/rs/zerolog/log" -) - -type WatchConfigDirectoryCloser func() error - -func ReadApiKeysJson(configDir string, options *schema.StartupOptions) error { - fileContent, err := os.ReadFile(path.Join(configDir, "api_keys.json")) - if err == nil { - // Parse JSON content from the file - var fileKeys []string - err := json.Unmarshal(fileContent, &fileKeys) - if err == nil { - options.ApiKeys = append(options.ApiKeys, fileKeys...) - return nil - } - return err - } - return err -} - -func ReadExternalBackendsJson(configDir string, options *schema.StartupOptions) error { - fileContent, err := os.ReadFile(path.Join(configDir, "external_backends.json")) - if err != nil { - return err - } - // Parse JSON content from the file - var fileBackends map[string]string - err = json.Unmarshal(fileContent, &fileBackends) - if err != nil { - return err - } - err = mergo.Merge(&options.ExternalGRPCBackends, fileBackends) - if err != nil { - return err - } - return nil -} - -var CONFIG_FILE_UPDATES = map[string]func(configDir string, options *schema.StartupOptions) error{ - "api_keys.json": ReadApiKeysJson, - "external_backends.json": ReadExternalBackendsJson, -} - -func WatchConfigDirectory(configDir string, options *schema.StartupOptions) (WatchConfigDirectoryCloser, error) { - if len(configDir) == 0 { - return nil, fmt.Errorf("configDir blank") - } - configWatcher, err := fsnotify.NewWatcher() - if err != nil { - log.Fatal().Msgf("Unable to create a watcher for the LocalAI Configuration Directory: %+v", err) - } - ret := func() error { - configWatcher.Close() - return nil - } - - // Start listening for events. - go func() { - for { - select { - case event, ok := <-configWatcher.Events: - if !ok { - return - } - if event.Has(fsnotify.Write) { - for targetName, watchFn := range CONFIG_FILE_UPDATES { - if event.Name == targetName { - err := watchFn(configDir, options) - log.Warn().Msgf("WatchConfigDirectory goroutine for %s: failed to update options: %+v", targetName, err) - } - } - } - case _, ok := <-configWatcher.Errors: - if !ok { - return - } - log.Error().Msgf("WatchConfigDirectory goroutine error: %+v", err) - } - } - }() - - // Add a path. - err = configWatcher.Add(configDir) - if err != nil { - return ret, fmt.Errorf("unable to establish watch on the LocalAI Configuration Directory: %+v", err) - } - - return ret, nil -} diff --git a/core/startup/startup.go b/core/startup/startup.go deleted file mode 100644 index 85db50761625..000000000000 --- a/core/startup/startup.go +++ /dev/null @@ -1,93 +0,0 @@ -package startup - -import ( - "github.com/go-skynet/LocalAI/core/services" - "github.com/go-skynet/LocalAI/internal" - "github.com/go-skynet/LocalAI/pkg/assets" - "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/schema" - "github.com/rs/zerolog" - "github.com/rs/zerolog/log" -) - -func Startup(opts ...schema.AppOption) (*services.ConfigLoader, *model.ModelLoader, *schema.StartupOptions, error) { - options := schema.NewStartupOptions(opts...) - - ml := model.NewModelLoader(options.ModelPath) - - zerolog.SetGlobalLevel(zerolog.InfoLevel) - if options.Debug { - zerolog.SetGlobalLevel(zerolog.DebugLevel) - } - - log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.ModelPath) - log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion()) - - cl := services.NewConfigLoader() - if err := cl.LoadConfigs(options.ModelPath); err != nil { - log.Error().Msgf("error loading config files: %s", err.Error()) - } - - if options.ConfigFile != "" { - if err := cl.LoadConfigFile(options.ConfigFile); err != nil { - log.Error().Msgf("error loading config file: %s", err.Error()) - } - } - - if err := cl.Preload(options.ModelPath); err != nil { - log.Error().Msgf("error downloading models: %s", err.Error()) - } - - if options.PreloadJSONModels != "" { - if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil { - return nil, nil, nil, err - } - } - - if options.PreloadModelsFromPath != "" { - if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil { - return nil, nil, nil, err - } - } - - if options.Debug { - for _, v := range cl.ListConfigs() { - cfg, _ := cl.GetConfig(v) - log.Debug().Msgf("Model: %s (config: %+v)", v, cfg) - } - } - - if options.AssetsDestination != "" { - // Extract files from the embedded FS - err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination) - log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination) - if err != nil { - log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err) - } - } - - // turn off any process that was started by GRPC if the context is canceled - go func() { - <-options.Context.Done() - log.Debug().Msgf("Context canceled, shutting down") - ml.StopAllGRPC() - }() - - if options.WatchDog { - wd := model.NewWatchDog( - ml, - options.WatchDogBusyTimeout, - options.WatchDogIdleTimeout, - options.WatchDogBusy, - options.WatchDogIdle) - ml.SetWatchDog(wd) - go wd.Run() - go func() { - <-options.Context.Done() - log.Debug().Msgf("Context canceled, shutting down") - wd.Shutdown() - }() - } - - return cl, ml, options, nil -} diff --git a/docs/content/advanced/development.md b/docs/content/advanced/development.md index afc6ce3bad8d..9f73b8a5b84f 100644 --- a/docs/content/advanced/development.md +++ b/docs/content/advanced/development.md @@ -17,53 +17,6 @@ This section will collect how-to, notes and development documentation We use conventional commits and semantic versioning. Please follow the [conventional commits](https://www.conventionalcommits.org/en/v1.0.0/) specification when writing commit messages. -## LocalAI Project Structure - -**LocalAI is made of multiple components, developed in multiple repositories:** - -The core repository, containing the primary `local-ai` server code, gRPC stubs, this documentation website, and docker container building resources are all located at [mudler/LocalAI](https://github.com/mudler/LocalAI). - -As LocalAI is designed to make use of multiple, independent model galleries, those are maintained seperately. The following public model galleries are available for use: - -* [go-skynet/model-gallery](https://github.com/go-skynet/model-gallery) - The original gallery, the `golang` huggingface scraper ran into limits and was largely retired, so this now holds handmade yaml configs -* [dave-gray101/model-gallery](https://github.com/dave-gray101/model-gallery) - An automated gallery designed to track HuggingFace uploads and produce best-effort automatically generated configurations for LocalAI. It is designed to produce one LocalAI gallery per repository on HuggingFace. - -### Directory Structure of this Repo - -The core repository is broken up into the following primary chunks: - -* `/backend`: gRPC protobuf specification and gRPC backends. Subfolders for each language. -* **`/core`**: golang sourcecode for the core LocalAI application. Broken down below. -* `/docs`: localai.io website that you are reading now -* `/examples`: example code integrating LocalAI to other projects and/or developer samples and tools -* `/internal`: **here be dragons**. Don't touch this, it's used for automatic versioning. -* `/models`: _No code here!_ This is where models are installed! -* **`/pkg`**: golang sourcecode that is intended to be reusable or at least widely imported across LocalAI. Broken down below -* `/prompt-templates`: _No code here!_ This is where **example** prompt templates were historically stored. Somewhat obsolete these days, model-galleries tend to replace manually creating these? -* `/tests`: Does what it says on the tin. Please write tests and put them here when you do. - -The `core` folder is broken down further: - -* **`/core/backend`**: code that interacts with a gRPC backend to perform AI tasks. -* `/core/http`: code specifically related to the REST server -* `/core/http/endpoints`: Has two subdirectories, `openai` and `localai` for binding the respective endpoints to the correct backend or service. -* `/core/mqtt`: core specifically related to the MQTT server. Stub for now. Coming soon! -* **`/core/services`**: code implementing functionality performed by `local-ai` itself, rather than delegated to a backend. -* `/core/startup`: code related specifically to application startup of `local-ai`. Potentially to be refactored to become a part of `/core/services` at a later date, or not. - -The `pkg` folder is broken down further: - -* `/pkg/assets`: Currently contains a single function related to extracting files from archives. Potentially to be refactored to become a part of `/core/utils` at a later date? -* **`/pkg/datamodel`**: Contains the data types and definitions used by the LocalAI project. Imported widely! -* `/pkg/gallery`: Code related to interacting with a `model-gallery` -* `/pkg/grammar`: Code related to BNF / functions for LLM -* `/pkg/grpc`: base classes and interfaces for gRPC backends to implement -* `/pkg/langchain`: langchain related code in golang -* **`/pkg/model`**: Code related to loading and initializing a model and creating the appropriate gRPC backend. -* `/pkg/stablediffusion`: Code related to stablediffusion in golang. -* `/pkg/utils`: Every real programmer knows what they are going to find in here... it's our junk drawer of utility functions. - - ## Creating a gRPC backend LocalAI backends are `gRPC` servers. diff --git a/docs/content/features/text-to-audio.md b/docs/content/features/text-to-audio.md index 88aba2f1f7da..ab038d2f5e5b 100644 --- a/docs/content/features/text-to-audio.md +++ b/docs/content/features/text-to-audio.md @@ -20,7 +20,7 @@ curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{ Returns an `audio/wav` file. -#### Text-To-Speech Setup +#### Setup LocalAI supports [bark]({{%relref "model-compatibility/bark" %}}) , `piper` and `vall-e-x`: @@ -52,8 +52,6 @@ Note: - The model name is case sensitive. - LocalAI must be compiled with the `GO_TAGS=tts` flag. -#### Music - LocalAI also has experimental support for `transformers-musicgen` for the generation of short musical compositions. Currently, this is implemented via the same requests used for text to speech: ``` @@ -64,8 +62,7 @@ curl --request POST \ "backend": "transformers-musicgen", "model": "facebook/musicgen-medium", "input": "Cello Rave" -}' | aplay -``` +}' | aplay``` Future versions of LocalAI will expose additional control over audio generation beyond the text prompt. diff --git a/go.mod b/go.mod index 43a0f66937d5..250a2361796f 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,6 @@ go 1.21 require ( github.com/M0Rf30/go-tiny-dream v0.0.0-20231128165230-772a9c0d9aaf github.com/donomii/go-rwkv.cpp v0.0.0-20230715075832-c898cd0f62df - github.com/fsnotify/fsnotify v1.7.0 github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e github.com/go-audio/wav v1.1.0 github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1 @@ -16,6 +15,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/hpcloud/tail v1.0.0 github.com/imdario/mergo v0.3.16 + github.com/json-iterator/go v1.1.12 github.com/mholt/archiver/v3 v3.5.1 github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af @@ -63,6 +63,8 @@ require ( github.com/klauspost/pgzip v1.2.5 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect github.com/nwaples/rardecode v1.1.0 // indirect github.com/pierrec/lz4/v4 v4.1.2 // indirect github.com/pkoukk/tiktoken-go v0.1.2 // indirect diff --git a/go.sum b/go.sum index a98e3781f99b..fc00bf6e2ae6 100644 --- a/go.sum +++ b/go.sum @@ -24,9 +24,8 @@ github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 h1:iFaUwBSo5Svw6L github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5/go.mod h1:qssHWj60/X5sZFNxpG4HBPDHVqxNm4DfnCKgrbZOT+s= github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= -github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= -github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e h1:KtbU2JR3lJuXFASHG2+sVLucfMPBjWKUUKByX6C81mQ= github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo= github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4= @@ -75,6 +74,7 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= @@ -88,6 +88,8 @@ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpO github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4= github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/k0kubun/go-ansi v0.0.0-20180517002512-3bf9e2903213/go.mod h1:vNUNkEQ1e29fT/6vq2aBdFsgNPmy8qMdSay1npru+Sw= github.com/klauspost/compress v1.4.1/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= github.com/klauspost/compress v1.11.4/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= @@ -117,6 +119,11 @@ github.com/mholt/archiver/v3 v3.5.1 h1:rDjOBX9JSF5BvoJGvjqK479aL70qh9DIpZCl+k7Cl github.com/mholt/archiver/v3 v3.5.1/go.mod h1:e3dqJ7H78uzsRSEACH1joayhuSyhnonssnDhppzS1L4= github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ= github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760 h1:OFVkSxR7CRSRSNm5dvpMRZwmSwWa8EMMnHbc84fW5tU= github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig= github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c h1:CI5uGwqBpN8N7BrSKC+nmdfw+9nPQIDyjHHlaIiitZI= diff --git a/main.go b/main.go index caae9d87c711..be4e4ed8417c 100644 --- a/main.go +++ b/main.go @@ -12,14 +12,14 @@ import ( "syscall" "time" - "github.com/go-skynet/LocalAI/core/backend" - "github.com/go-skynet/LocalAI/core/http" - "github.com/go-skynet/LocalAI/core/services" - "github.com/go-skynet/LocalAI/core/startup" + api "github.com/go-skynet/LocalAI/api" + "github.com/go-skynet/LocalAI/api/backend" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/internal" + "github.com/go-skynet/LocalAI/metrics" "github.com/go-skynet/LocalAI/pkg/gallery" - "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/schema" + model "github.com/go-skynet/LocalAI/pkg/model" "github.com/rs/zerolog" "github.com/rs/zerolog/log" progressbar "github.com/schollz/progressbar/v3" @@ -190,12 +190,6 @@ func main() { EnvVars: []string{"PRELOAD_BACKEND_ONLY"}, Value: false, }, - &cli.StringFlag{ - Name: "localai-config-dir", - Usage: "Directory to use for the configuration files of LocalAI itself. This is NOT where model files should be placed.", - EnvVars: []string{"LOCALAI_CONFIG_DIR"}, - Value: "./config", - }, }, Description: ` LocalAI is a drop-in replacement OpenAI API which runs inference locally. @@ -214,54 +208,54 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit UsageText: `local-ai [options]`, Copyright: "Ettore Di Giacinto", Action: func(ctx *cli.Context) error { - opts := []schema.AppOption{ - schema.WithConfigFile(ctx.String("config-file")), - schema.WithJSONStringPreload(ctx.String("preload-models")), - schema.WithYAMLConfigPreload(ctx.String("preload-models-config")), - schema.WithModelPath(ctx.String("models-path")), - schema.WithContextSize(ctx.Int("context-size")), - schema.WithDebug(ctx.Bool("debug")), - schema.WithImageDir(ctx.String("image-path")), - schema.WithAudioDir(ctx.String("audio-path")), - schema.WithF16(ctx.Bool("f16")), - schema.WithStringGalleries(ctx.String("galleries")), - schema.WithDisableMessage(false), - schema.WithCors(ctx.Bool("cors")), - schema.WithCorsAllowOrigins(ctx.String("cors-allow-origins")), - schema.WithThreads(ctx.Int("threads")), - schema.WithBackendAssets(backendAssets), - schema.WithBackendAssetsOutput(ctx.String("backend-assets-path")), - schema.WithUploadLimitMB(ctx.Int("upload-limit")), - schema.WithApiKeys(ctx.StringSlice("api-keys")), - schema.WithModelsURL(append(ctx.StringSlice("models"), ctx.Args().Slice()...)...), + opts := []options.AppOption{ + options.WithConfigFile(ctx.String("config-file")), + options.WithJSONStringPreload(ctx.String("preload-models")), + options.WithYAMLConfigPreload(ctx.String("preload-models-config")), + options.WithModelLoader(model.NewModelLoader(ctx.String("models-path"))), + options.WithContextSize(ctx.Int("context-size")), + options.WithDebug(ctx.Bool("debug")), + options.WithImageDir(ctx.String("image-path")), + options.WithAudioDir(ctx.String("audio-path")), + options.WithF16(ctx.Bool("f16")), + options.WithStringGalleries(ctx.String("galleries")), + options.WithDisableMessage(false), + options.WithCors(ctx.Bool("cors")), + options.WithCorsAllowOrigins(ctx.String("cors-allow-origins")), + options.WithThreads(ctx.Int("threads")), + options.WithBackendAssets(backendAssets), + options.WithBackendAssetsOutput(ctx.String("backend-assets-path")), + options.WithUploadLimitMB(ctx.Int("upload-limit")), + options.WithApiKeys(ctx.StringSlice("api-keys")), + options.WithModelsURL(append(ctx.StringSlice("models"), ctx.Args().Slice()...)...), } idleWatchDog := ctx.Bool("enable-watchdog-idle") busyWatchDog := ctx.Bool("enable-watchdog-busy") if idleWatchDog || busyWatchDog { - opts = append(opts, schema.EnableWatchDog) + opts = append(opts, options.EnableWatchDog) if idleWatchDog { - opts = append(opts, schema.EnableWatchDogIdleCheck) + opts = append(opts, options.EnableWatchDogIdleCheck) dur, err := time.ParseDuration(ctx.String("watchdog-idle-timeout")) if err != nil { return err } - opts = append(opts, schema.SetWatchDogIdleTimeout(dur)) + opts = append(opts, options.SetWatchDogIdleTimeout(dur)) } if busyWatchDog { - opts = append(opts, schema.EnableWatchDogBusyCheck) + opts = append(opts, options.EnableWatchDogBusyCheck) dur, err := time.ParseDuration(ctx.String("watchdog-busy-timeout")) if err != nil { return err } - opts = append(opts, schema.SetWatchDogBusyTimeout(dur)) + opts = append(opts, options.SetWatchDogBusyTimeout(dur)) } } if ctx.Bool("parallel-requests") { - opts = append(opts, schema.EnableParallelBackendRequests) + opts = append(opts, options.EnableParallelBackendRequests) } if ctx.Bool("single-active-backend") { - opts = append(opts, schema.EnableSingleBackend) + opts = append(opts, options.EnableSingleBackend) } externalgRPC := ctx.StringSlice("external-grpc-backends") @@ -269,42 +263,30 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit for _, v := range externalgRPC { backend := v[:strings.IndexByte(v, ':')] uri := v[strings.IndexByte(v, ':')+1:] - opts = append(opts, schema.WithExternalBackend(backend, uri)) + opts = append(opts, options.WithExternalBackend(backend, uri)) } if ctx.Bool("autoload-galleries") { - opts = append(opts, schema.EnableGalleriesAutoload) + opts = append(opts, options.EnableGalleriesAutoload) } if ctx.Bool("preload-backend-only") { - _, _, _, err := startup.Startup(opts...) + _, _, err := api.Startup(opts...) return err } - metrics, err := services.SetupMetrics() + metrics, err := metrics.SetupMetrics() if err != nil { return err } - opts = append(opts, schema.WithMetrics(metrics)) - - cl, ml, options, err := startup.Startup(opts...) - if err != nil { - return fmt.Errorf("failed basic startup tasks with error %s", err.Error()) - } - - closeConfigWatcherFn, err := startup.WatchConfigDirectory(ctx.String("localai-config-dir"), options) + opts = append(opts, options.WithMetrics(metrics)) - defer closeConfigWatcherFn() - if err != nil { - return fmt.Errorf("failed while watching configuration directory %s", ctx.String("localai-config-dir")) - } - - appHTTP, err := http.App(cl, ml, options) + app, err := api.App(opts...) if err != nil { return err } - return appHTTP.Listen(ctx.String("address")) + return app.Listen(ctx.String("address")) }, Commands: []*cli.Command{ { @@ -402,18 +384,16 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit text := strings.Join(ctx.Args().Slice(), " ") - opts := &schema.StartupOptions{ - ModelPath: ctx.String("models-path"), + opts := &options.Option{ + Loader: model.NewModelLoader(ctx.String("models-path")), Context: context.Background(), AudioDir: outputDir, AssetsDestination: ctx.String("backend-assets-path"), } - loader := model.NewModelLoader(opts.ModelPath) + defer opts.Loader.StopAllGRPC() - defer loader.StopAllGRPC() - - filePath, _, err := backend.ModelTTS(backendOption, text, modelOption, loader, opts) + filePath, _, err := backend.ModelTTS(backendOption, text, modelOption, opts.Loader, opts) if err != nil { return err } @@ -466,15 +446,13 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit language := ctx.String("language") threads := ctx.Int("threads") - opts := &schema.StartupOptions{ - ModelPath: ctx.String("models-path"), + opts := &options.Option{ + Loader: model.NewModelLoader(ctx.String("models-path")), Context: context.Background(), AssetsDestination: ctx.String("backend-assets-path"), } - ml := model.NewModelLoader(opts.ModelPath) - - cl := services.NewConfigLoader() + cl := config.NewConfigLoader() if err := cl.LoadConfigs(ctx.String("models-path")); err != nil { return err } @@ -486,9 +464,9 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit c.Threads = threads - defer ml.StopAllGRPC() + defer opts.Loader.StopAllGRPC() - tr, err := backend.ModelTranscription(filename, language, ml, c, opts) + tr, err := backend.ModelTranscription(filename, language, opts.Loader, c, opts) if err != nil { return err } diff --git a/metrics/metrics.go b/metrics/metrics.go new file mode 100644 index 000000000000..84b83161fdc1 --- /dev/null +++ b/metrics/metrics.go @@ -0,0 +1,83 @@ +package metrics + +import ( + "context" + "time" + + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/adaptor" + "github.com/prometheus/client_golang/prometheus/promhttp" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/exporters/prometheus" + api "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/sdk/metric" +) + +type Metrics struct { + meter api.Meter + apiTimeMetric api.Float64Histogram +} + +// setupOTelSDK bootstraps the OpenTelemetry pipeline. +// If it does not return an error, make sure to call shutdown for proper cleanup. +func SetupMetrics() (*Metrics, error) { + exporter, err := prometheus.New() + if err != nil { + return nil, err + } + provider := metric.NewMeterProvider(metric.WithReader(exporter)) + meter := provider.Meter("github.com/go-skynet/LocalAI") + + apiTimeMetric, err := meter.Float64Histogram("api_call", api.WithDescription("api calls")) + if err != nil { + return nil, err + } + + return &Metrics{ + meter: meter, + apiTimeMetric: apiTimeMetric, + }, nil +} + +func MetricsHandler() fiber.Handler { + return adaptor.HTTPHandler(promhttp.Handler()) +} + +type apiMiddlewareConfig struct { + Filter func(c *fiber.Ctx) bool + metrics *Metrics +} + +func APIMiddleware(metrics *Metrics) fiber.Handler { + cfg := apiMiddlewareConfig{ + metrics: metrics, + Filter: func(c *fiber.Ctx) bool { + if c.Path() == "/metrics" { + return true + } + return false + }, + } + + return func(c *fiber.Ctx) error { + if cfg.Filter != nil && cfg.Filter(c) { + return c.Next() + } + path := c.Path() + method := c.Method() + + start := time.Now() + err := c.Next() + elapsed := float64(time.Since(start)) / float64(time.Second) + cfg.metrics.ObserveAPICall(method, path, elapsed) + return err + } +} + +func (m *Metrics) ObserveAPICall(method string, path string, duration float64) { + opts := api.WithAttributes( + attribute.String("method", method), + attribute.String("path", path), + ) + m.apiTimeMetric.Record(context.Background(), duration, opts) +} diff --git a/pkg/gallery/gallery.go b/pkg/gallery/gallery.go index 4aeb3172fa6a..7957ed59d638 100644 --- a/pkg/gallery/gallery.go +++ b/pkg/gallery/gallery.go @@ -22,11 +22,11 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string, applyModel := func(model *GalleryModel) error { name = strings.ReplaceAll(name, string(os.PathSeparator), "__") - var config InstallableModel + var config Config if len(model.URL) > 0 { var err error - config, err = GetInstallableModelFromURL(model.URL) + config, err = GetGalleryConfigFromURL(model.URL) if err != nil { return err } @@ -36,7 +36,7 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string, if err != nil { return err } - config = InstallableModel{ + config = Config{ ConfigFile: string(reYamlConfig), Description: model.Description, License: model.License, diff --git a/pkg/gallery/models.go b/pkg/gallery/models.go index 2e8770f17625..9a1697981614 100644 --- a/pkg/gallery/models.go +++ b/pkg/gallery/models.go @@ -1,9 +1,13 @@ package gallery import ( + "crypto/sha256" "fmt" + "hash" + "io" "os" "path/filepath" + "strconv" "github.com/go-skynet/LocalAI/pkg/utils" "github.com/imdario/mergo" @@ -37,9 +41,9 @@ prompt_templates: content: "" */ -// InstallableModel is the model configuration which contains all the model details +// Config is the model configuration which contains all the model details // This configuration is read from the gallery endpoint and is used to download and install the model -type InstallableModel struct { +type Config struct { Description string `yaml:"description"` License string `yaml:"license"` URLs []string `yaml:"urls"` @@ -60,8 +64,8 @@ type PromptTemplate struct { Content string `yaml:"content"` } -func GetInstallableModelFromURL(url string) (InstallableModel, error) { - var config InstallableModel +func GetGalleryConfigFromURL(url string) (Config, error) { + var config Config err := utils.GetURI(url, func(url string, d []byte) error { return yaml.Unmarshal(d, &config) }) @@ -72,7 +76,7 @@ func GetInstallableModelFromURL(url string) (InstallableModel, error) { return config, nil } -func ReadInstallableModelFile(filePath string) (*InstallableModel, error) { +func ReadConfigFile(filePath string) (*Config, error) { // Read the YAML file yamlFile, err := os.ReadFile(filePath) if err != nil { @@ -80,7 +84,7 @@ func ReadInstallableModelFile(filePath string) (*InstallableModel, error) { } // Unmarshal YAML data into a Config struct - var config InstallableModel + var config Config err = yaml.Unmarshal(yamlFile, &config) if err != nil { return nil, fmt.Errorf("failed to unmarshal YAML: %v", err) @@ -89,7 +93,7 @@ func ReadInstallableModelFile(filePath string) (*InstallableModel, error) { return &config, nil } -func InstallModel(basePath, nameOverride string, config *InstallableModel, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error { +func InstallModel(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error { // Create base path if it doesn't exist err := os.MkdirAll(basePath, 0755) if err != nil { @@ -179,3 +183,54 @@ func InstallModel(basePath, nameOverride string, config *InstallableModel, confi return nil } + +type progressWriter struct { + fileName string + total int64 + written int64 + downloadStatus func(string, string, string, float64) + hash hash.Hash +} + +func (pw *progressWriter) Write(p []byte) (n int, err error) { + n, err = pw.hash.Write(p) + pw.written += int64(n) + + if pw.total > 0 { + percentage := float64(pw.written) / float64(pw.total) * 100 + //log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage) + pw.downloadStatus(pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage) + } else { + pw.downloadStatus(pw.fileName, formatBytes(pw.written), "", 0) + } + + return +} + +func formatBytes(bytes int64) string { + const unit = 1024 + if bytes < unit { + return strconv.FormatInt(bytes, 10) + " B" + } + div, exp := int64(unit), 0 + for n := bytes / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp]) +} + +func calculateSHA(filePath string) (string, error) { + file, err := os.Open(filePath) + if err != nil { + return "", err + } + defer file.Close() + + hash := sha256.New() + if _, err := io.Copy(hash, file); err != nil { + return "", err + } + + return fmt.Sprintf("%x", hash.Sum(nil)), nil +} diff --git a/pkg/gallery/models_test.go b/pkg/gallery/models_test.go index 96ed17e06b6a..f454c6111aea 100644 --- a/pkg/gallery/models_test.go +++ b/pkg/gallery/models_test.go @@ -16,7 +16,7 @@ var _ = Describe("Model test", func() { tempdir, err := os.MkdirTemp("", "test") Expect(err).ToNot(HaveOccurred()) defer os.RemoveAll(tempdir) - c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) + c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) err = InstallModel(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {}) @@ -87,7 +87,7 @@ var _ = Describe("Model test", func() { tempdir, err := os.MkdirTemp("", "test") Expect(err).ToNot(HaveOccurred()) defer os.RemoveAll(tempdir) - c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) + c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) err = InstallModel(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}) @@ -103,7 +103,7 @@ var _ = Describe("Model test", func() { tempdir, err := os.MkdirTemp("", "test") Expect(err).ToNot(HaveOccurred()) defer os.RemoveAll(tempdir) - c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) + c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) err = InstallModel(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}) @@ -129,7 +129,7 @@ var _ = Describe("Model test", func() { tempdir, err := os.MkdirTemp("", "test") Expect(err).ToNot(HaveOccurred()) defer os.RemoveAll(tempdir) - c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) + c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) err = InstallModel(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}) diff --git a/pkg/gallery/op.go b/pkg/gallery/op.go deleted file mode 100644 index 873c356d3056..000000000000 --- a/pkg/gallery/op.go +++ /dev/null @@ -1,18 +0,0 @@ -package gallery - -type GalleryOp struct { - Req GalleryModel - Id string - Galleries []Gallery - GalleryName string -} - -type GalleryOpStatus struct { - FileName string `json:"file_name"` - Error error `json:"error"` - Processed bool `json:"processed"` - Message string `json:"message"` - Progress float64 `json:"progress"` - TotalFileSize string `json:"file_size"` - DownloadedFileSize string `json:"downloaded_size"` -} diff --git a/pkg/gallery/request_test.go b/pkg/gallery/request_test.go index 017167d908f9..a9d54e325042 100644 --- a/pkg/gallery/request_test.go +++ b/pkg/gallery/request_test.go @@ -10,7 +10,7 @@ var _ = Describe("Gallery API tests", func() { Context("requests", func() { It("parses github with a branch", func() { req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"} - e, err := GetInstallableModelFromURL(req.URL) + e, err := GetGalleryConfigFromURL(req.URL) Expect(err).ToNot(HaveOccurred()) Expect(e.Name).To(Equal("gpt4all-j")) }) diff --git a/pkg/grpc/base/base.go b/pkg/grpc/base/base.go index e567a1251227..739d1cbbe6bb 100644 --- a/pkg/grpc/base/base.go +++ b/pkg/grpc/base/base.go @@ -6,8 +6,8 @@ import ( "fmt" "os" + "github.com/go-skynet/LocalAI/api/schema" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/go-skynet/LocalAI/pkg/schema" gopsutil "github.com/shirou/gopsutil/v3/process" ) @@ -53,9 +53,8 @@ func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error { return fmt.Errorf("unimplemented") } -// TODO CHECK THIS -func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.WhisperResult, error) { - return schema.WhisperResult{}, fmt.Errorf("unimplemented") +func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.Result, error) { + return schema.Result{}, fmt.Errorf("unimplemented") } func (llm *Base) TTS(*pb.TTSRequest) error { diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index adf2d9de59ab..9eab356d487c 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -7,8 +7,8 @@ import ( "sync" "time" + "github.com/go-skynet/LocalAI/api/schema" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/go-skynet/LocalAI/pkg/schema" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) @@ -223,7 +223,7 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp return client.TTS(ctx, in, opts...) } -func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.WhisperResult, error) { +func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) { if !c.parallel { c.opMutex.Lock() defer c.opMutex.Unlock() @@ -244,14 +244,14 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques if err != nil { return nil, err } - tresult := &schema.WhisperResult{} + tresult := &schema.Result{} for _, s := range res.Segments { tks := []int{} for _, t := range s.Tokens { tks = append(tks, int(t)) } tresult.Segments = append(tresult.Segments, - schema.WhisperSegment{ + schema.Segment{ Text: s.Text, Id: int(s.Id), Start: time.Duration(s.Start), diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go index fbd126a5bfed..a76261c15ce9 100644 --- a/pkg/grpc/interface.go +++ b/pkg/grpc/interface.go @@ -1,8 +1,8 @@ package grpc import ( + "github.com/go-skynet/LocalAI/api/schema" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/go-skynet/LocalAI/pkg/schema" ) type LLM interface { @@ -15,7 +15,7 @@ type LLM interface { Load(*pb.ModelOptions) error Embeddings(*pb.PredictOptions) ([]float32, error) GenerateImage(*pb.GenerateImageRequest) error - AudioTranscription(*pb.TranscriptRequest) (schema.WhisperResult, error) + AudioTranscription(*pb.TranscriptRequest) (schema.Result, error) TTS(*pb.TTSRequest) error TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error) Status() (pb.StatusResponse, error) diff --git a/pkg/grpc/proto/backend.pb.go b/pkg/grpc/proto/backend.pb.go index 2e4a2e9b22fa..b9569785eef6 100644 --- a/pkg/grpc/proto/backend.pb.go +++ b/pkg/grpc/proto/backend.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.26.0 -// protoc v4.26.0 +// protoc-gen-go v1.28.1 +// protoc v3.6.1 // source: backend.proto package proto diff --git a/pkg/grpc/proto/backend_grpc.pb.go b/pkg/grpc/proto/backend_grpc.pb.go index 41a1ba55aadd..d41f77a61446 100644 --- a/pkg/grpc/proto/backend_grpc.pb.go +++ b/pkg/grpc/proto/backend_grpc.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.3.0 -// - protoc v4.26.0 +// - protoc-gen-go-grpc v1.2.0 +// - protoc v3.6.1 // source: backend.proto package proto @@ -18,19 +18,6 @@ import ( // Requires gRPC-Go v1.32.0 or later. const _ = grpc.SupportPackageIsVersion7 -const ( - Backend_Health_FullMethodName = "/backend.Backend/Health" - Backend_Predict_FullMethodName = "/backend.Backend/Predict" - Backend_LoadModel_FullMethodName = "/backend.Backend/LoadModel" - Backend_PredictStream_FullMethodName = "/backend.Backend/PredictStream" - Backend_Embedding_FullMethodName = "/backend.Backend/Embedding" - Backend_GenerateImage_FullMethodName = "/backend.Backend/GenerateImage" - Backend_AudioTranscription_FullMethodName = "/backend.Backend/AudioTranscription" - Backend_TTS_FullMethodName = "/backend.Backend/TTS" - Backend_TokenizeString_FullMethodName = "/backend.Backend/TokenizeString" - Backend_Status_FullMethodName = "/backend.Backend/Status" -) - // BackendClient is the client API for Backend service. // // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. @@ -57,7 +44,7 @@ func NewBackendClient(cc grpc.ClientConnInterface) BackendClient { func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) { out := new(Reply) - err := c.cc.Invoke(ctx, Backend_Health_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, "/backend.Backend/Health", in, out, opts...) if err != nil { return nil, err } @@ -66,7 +53,7 @@ func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...g func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) { out := new(Reply) - err := c.cc.Invoke(ctx, Backend_Predict_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, "/backend.Backend/Predict", in, out, opts...) if err != nil { return nil, err } @@ -75,7 +62,7 @@ func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts .. func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) { out := new(Result) - err := c.cc.Invoke(ctx, Backend_LoadModel_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, "/backend.Backend/LoadModel", in, out, opts...) if err != nil { return nil, err } @@ -83,7 +70,7 @@ func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts .. } func (c *backendClient) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error) { - stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], Backend_PredictStream_FullMethodName, opts...) + stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], "/backend.Backend/PredictStream", opts...) if err != nil { return nil, err } @@ -116,7 +103,7 @@ func (x *backendPredictStreamClient) Recv() (*Reply, error) { func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) { out := new(EmbeddingResult) - err := c.cc.Invoke(ctx, Backend_Embedding_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, "/backend.Backend/Embedding", in, out, opts...) if err != nil { return nil, err } @@ -125,7 +112,7 @@ func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) { out := new(Result) - err := c.cc.Invoke(ctx, Backend_GenerateImage_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, "/backend.Backend/GenerateImage", in, out, opts...) if err != nil { return nil, err } @@ -134,7 +121,7 @@ func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequ func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) { out := new(TranscriptResult) - err := c.cc.Invoke(ctx, Backend_AudioTranscription_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, "/backend.Backend/AudioTranscription", in, out, opts...) if err != nil { return nil, err } @@ -143,7 +130,7 @@ func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRe func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) { out := new(Result) - err := c.cc.Invoke(ctx, Backend_TTS_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, "/backend.Backend/TTS", in, out, opts...) if err != nil { return nil, err } @@ -152,7 +139,7 @@ func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.Ca func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*TokenizationResponse, error) { out := new(TokenizationResponse) - err := c.cc.Invoke(ctx, Backend_TokenizeString_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, "/backend.Backend/TokenizeString", in, out, opts...) if err != nil { return nil, err } @@ -161,7 +148,7 @@ func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions, func (c *backendClient) Status(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*StatusResponse, error) { out := new(StatusResponse) - err := c.cc.Invoke(ctx, Backend_Status_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, "/backend.Backend/Status", in, out, opts...) if err != nil { return nil, err } @@ -242,7 +229,7 @@ func _Backend_Health_Handler(srv interface{}, ctx context.Context, dec func(inte } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: Backend_Health_FullMethodName, + FullMethod: "/backend.Backend/Health", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(BackendServer).Health(ctx, req.(*HealthMessage)) @@ -260,7 +247,7 @@ func _Backend_Predict_Handler(srv interface{}, ctx context.Context, dec func(int } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: Backend_Predict_FullMethodName, + FullMethod: "/backend.Backend/Predict", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(BackendServer).Predict(ctx, req.(*PredictOptions)) @@ -278,7 +265,7 @@ func _Backend_LoadModel_Handler(srv interface{}, ctx context.Context, dec func(i } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: Backend_LoadModel_FullMethodName, + FullMethod: "/backend.Backend/LoadModel", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(BackendServer).LoadModel(ctx, req.(*ModelOptions)) @@ -317,7 +304,7 @@ func _Backend_Embedding_Handler(srv interface{}, ctx context.Context, dec func(i } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: Backend_Embedding_FullMethodName, + FullMethod: "/backend.Backend/Embedding", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(BackendServer).Embedding(ctx, req.(*PredictOptions)) @@ -335,7 +322,7 @@ func _Backend_GenerateImage_Handler(srv interface{}, ctx context.Context, dec fu } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: Backend_GenerateImage_FullMethodName, + FullMethod: "/backend.Backend/GenerateImage", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(BackendServer).GenerateImage(ctx, req.(*GenerateImageRequest)) @@ -353,7 +340,7 @@ func _Backend_AudioTranscription_Handler(srv interface{}, ctx context.Context, d } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: Backend_AudioTranscription_FullMethodName, + FullMethod: "/backend.Backend/AudioTranscription", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(BackendServer).AudioTranscription(ctx, req.(*TranscriptRequest)) @@ -371,7 +358,7 @@ func _Backend_TTS_Handler(srv interface{}, ctx context.Context, dec func(interfa } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: Backend_TTS_FullMethodName, + FullMethod: "/backend.Backend/TTS", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(BackendServer).TTS(ctx, req.(*TTSRequest)) @@ -389,7 +376,7 @@ func _Backend_TokenizeString_Handler(srv interface{}, ctx context.Context, dec f } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: Backend_TokenizeString_FullMethodName, + FullMethod: "/backend.Backend/TokenizeString", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(BackendServer).TokenizeString(ctx, req.(*PredictOptions)) @@ -407,7 +394,7 @@ func _Backend_Status_Handler(srv interface{}, ctx context.Context, dec func(inte } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: Backend_Status_FullMethodName, + FullMethod: "/backend.Backend/Status", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(BackendServer).Status(ctx, req.(*HealthMessage)) diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 305393e03e42..c2182918f835 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -8,7 +8,7 @@ import ( "strings" "time" - "github.com/go-skynet/LocalAI/pkg/grpc" + grpc "github.com/go-skynet/LocalAI/pkg/grpc" "github.com/hashicorp/go-multierror" "github.com/phayes/freeport" "github.com/rs/zerolog/log" @@ -71,7 +71,7 @@ var AutoLoadBackends []string = []string{ // starts the grpcModelProcess for the backend, and returns a grpc client // It also loads the model -func (ml *ModelLoader) grpcModel(backend string, o *ModelOptions) func(string, string) (ModelAddress, error) { +func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string) (ModelAddress, error) { return func(modelName, modelFile string) (ModelAddress, error) { log.Debug().Msgf("Loading Model %s with gRPC (file: %s) (backend: %s): %+v", modelName, modelFile, backend, *o) diff --git a/pkg/model/loader.go b/pkg/model/loader.go index aafa313b068a..d02f9e84c959 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -10,7 +10,7 @@ import ( "sync" "text/template" - "github.com/go-skynet/LocalAI/pkg/grammar" + grammar "github.com/go-skynet/LocalAI/pkg/grammar" "github.com/go-skynet/LocalAI/pkg/grpc" process "github.com/mudler/go-processmanager" "github.com/rs/zerolog/log" diff --git a/pkg/model/options.go b/pkg/model/options.go index f7cfbe1ad372..5748be9be59e 100644 --- a/pkg/model/options.go +++ b/pkg/model/options.go @@ -6,7 +6,7 @@ import ( pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" ) -type ModelOptions struct { +type Options struct { backendString string model string threads uint32 @@ -23,14 +23,14 @@ type ModelOptions struct { parallelRequests bool } -type Option func(*ModelOptions) +type Option func(*Options) -var EnableParallelRequests = func(o *ModelOptions) { +var EnableParallelRequests = func(o *Options) { o.parallelRequests = true } func WithExternalBackend(name string, uri string) Option { - return func(o *ModelOptions) { + return func(o *Options) { if o.externalBackends == nil { o.externalBackends = make(map[string]string) } @@ -38,81 +38,62 @@ func WithExternalBackend(name string, uri string) Option { } } -// Currently, LocalAI isn't ready for backends to be yanked out from under it - so this is a little overcomplicated to allow non-overwriting updates -func WithExternalBackends(backends map[string]string, overwrite bool) Option { - return func(o *ModelOptions) { - if backends == nil { - return - } - if o.externalBackends == nil { - o.externalBackends = backends - return - } - for name, url := range backends { - _, exists := o.externalBackends[name] - if !exists || overwrite { - o.externalBackends[name] = url - } - } - } -} - func WithGRPCAttempts(attempts int) Option { - return func(o *ModelOptions) { + return func(o *Options) { o.grpcAttempts = attempts } } func WithGRPCAttemptsDelay(delay int) Option { - return func(o *ModelOptions) { + return func(o *Options) { o.grpcAttemptsDelay = delay } } func WithBackendString(backend string) Option { - return func(o *ModelOptions) { + return func(o *Options) { o.backendString = backend } } func WithModel(modelFile string) Option { - return func(o *ModelOptions) { + return func(o *Options) { o.model = modelFile } } func WithLoadGRPCLoadModelOpts(opts *pb.ModelOptions) Option { - return func(o *ModelOptions) { + return func(o *Options) { o.gRPCOptions = opts } } func WithThreads(threads uint32) Option { - return func(o *ModelOptions) { + return func(o *Options) { o.threads = threads } } func WithAssetDir(assetDir string) Option { - return func(o *ModelOptions) { + return func(o *Options) { o.assetDir = assetDir } } func WithContext(ctx context.Context) Option { - return func(o *ModelOptions) { + return func(o *Options) { o.context = ctx } } func WithSingleActiveBackend() Option { - return func(o *ModelOptions) { + return func(o *Options) { o.singleActiveBackend = true } } -func NewOptions(opts ...Option) *ModelOptions { - o := &ModelOptions{ +func NewOptions(opts ...Option) *Options { + o := &Options{ gRPCOptions: &pb.ModelOptions{}, context: context.Background(), grpcAttempts: 20, diff --git a/pkg/schema/localai.go b/pkg/schema/localai.go deleted file mode 100644 index c62e89ce874a..000000000000 --- a/pkg/schema/localai.go +++ /dev/null @@ -1,39 +0,0 @@ -package schema - -import ( - "context" - - gopsutil "github.com/shirou/gopsutil/v3/process" - - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/metric" -) - -type BackendMonitorRequest struct { - Model string `json:"model" yaml:"model"` -} - -type BackendMonitorResponse struct { - MemoryInfo *gopsutil.MemoryInfoStat - MemoryPercent float32 - CPUPercent float64 -} - -type TTSRequest struct { - Model string `json:"model" yaml:"model"` - Input string `json:"input" yaml:"input"` - Backend string `json:"backend" yaml:"backend"` -} - -type LocalAIMetrics struct { - Meter metric.Meter - ApiTimeMetric metric.Float64Histogram -} - -func (m *LocalAIMetrics) ObserveAPICall(method string, path string, duration float64) { - opts := metric.WithAttributes( - attribute.String("method", method), - attribute.String("path", path), - ) - m.ApiTimeMetric.Record(context.Background(), duration, opts) -} diff --git a/pkg/utils/file.go b/pkg/utils/file.go deleted file mode 100644 index fbeca6e5c9fd..000000000000 --- a/pkg/utils/file.go +++ /dev/null @@ -1,81 +0,0 @@ -package utils - -import ( - "bufio" - "encoding/base64" - "fmt" - "io" - "mime/multipart" - "net/http" - "os" - - "github.com/rs/zerolog/log" -) - -func CreateTempFileFromMultipartFile(file *multipart.FileHeader, tempDir string, tempPattern string) (string, error) { - - f, err := file.Open() - if err != nil { - return "", err - } - defer f.Close() - - // Create a temporary file in the requested directory: - outputFile, err := os.CreateTemp(tempDir, tempPattern) - if err != nil { - return "", err - } - defer outputFile.Close() - - if _, err := io.Copy(outputFile, f); err != nil { - log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, outputFile, err) - return "", err - } - - return outputFile.Name(), nil -} - -func CreateTempFileFromBase64(base64data string, tempDir string, tempPattern string) (string, error) { - if len(base64data) == 0 { - return "", fmt.Errorf("base64data empty?") - } - //base 64 decode the file and write it somewhere - // that we will cleanup - decoded, err := base64.StdEncoding.DecodeString(base64data) - if err != nil { - return "", err - } - // Create a temporary file in the requested directory: - outputFile, err := os.CreateTemp(tempDir, tempPattern) - if err != nil { - return "", err - } - defer outputFile.Close() - // write the base64 result - writer := bufio.NewWriter(outputFile) - _, err = writer.Write(decoded) - if err != nil { - return "", err - } - return outputFile.Name(), nil -} - -func CreateTempFileFromUrl(url string, tempDir string, tempPattern string) (string, error) { - // Get the data - resp, err := http.Get(url) - if err != nil { - return "", err - } - defer resp.Body.Close() - - // Create the file - out, err := os.CreateTemp(tempDir, tempPattern) - if err != nil { - return "", err - } - defer out.Close() - - // Write the body to file - _, err = io.Copy(out, resp.Body) - return out.Name(), err -} diff --git a/pkg/utils/uri.go b/pkg/utils/uri.go index 45e842bd1dcf..185e44b9610f 100644 --- a/pkg/utils/uri.go +++ b/pkg/utils/uri.go @@ -3,38 +3,18 @@ package utils import ( "crypto/md5" "crypto/sha256" - "encoding/base64" "fmt" "hash" "io" "net/http" "os" "path/filepath" - "slices" "strconv" "strings" "github.com/rs/zerolog/log" ) -const ( - HuggingFacePrefix = "huggingface://" - HTTPPrefix = "http://" - HTTPSPrefix = "https://" - GithubURI = "github:" - GithubURI2 = "github://" -) - -func getRecognizedURIPrefixes() []string { - return []string{ - HuggingFacePrefix, - HTTPPrefix, - HTTPSPrefix, - GithubURI, - GithubURI2, - } -} - func GetURI(url string, f func(url string, i []byte) error) error { url = ConvertURL(url) @@ -72,8 +52,20 @@ func GetURI(url string, f func(url string, i []byte) error) error { return f(url, body) } +const ( + HuggingFacePrefix = "huggingface://" + HTTPPrefix = "http://" + HTTPSPrefix = "https://" + GithubURI = "github:" + GithubURI2 = "github://" +) + func LooksLikeURL(s string) bool { - return slices.Contains(getRecognizedURIPrefixes(), s) + return strings.HasPrefix(s, HTTPPrefix) || + strings.HasPrefix(s, HTTPSPrefix) || + strings.HasPrefix(s, HuggingFacePrefix) || + strings.HasPrefix(s, GithubURI) || + strings.HasPrefix(s, GithubURI2) } func ConvertURL(s string) string { @@ -249,37 +241,6 @@ func DownloadFile(url string, filePath, sha string, downloadStatus func(string, return nil } -// this function check if the string is an URL, if it's an URL downloads the image in memory -// encodes it in base64 and returns the base64 string -func GetBase64Image(s string) (string, error) { - if strings.HasPrefix(s, "http") { - // download the image - resp, err := http.Get(s) - if err != nil { - return "", err - } - defer resp.Body.Close() - - // read the image data into memory - data, err := io.ReadAll(resp.Body) - if err != nil { - return "", err - } - - // encode the image data in base64 - encoded := base64.StdEncoding.EncodeToString(data) - - // return the base64 string - return encoded, nil - } - - // if the string instead is prefixed with "data:image/jpeg;base64,", drop it - if strings.HasPrefix(s, "data:image/jpeg;base64,") { - return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil - } - return "", fmt.Errorf("not valid string") -} - type progressWriter struct { fileName string total int64 diff --git a/tests/integration/reflect_test.go b/tests/integration/reflect_test.go index cb892244409f..c0fe7096a1d8 100644 --- a/tests/integration/reflect_test.go +++ b/tests/integration/reflect_test.go @@ -3,16 +3,16 @@ package integration_test import ( "reflect" - "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/schema" + config "github.com/go-skynet/LocalAI/api/config" + model "github.com/go-skynet/LocalAI/pkg/model" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) var _ = Describe("Integration Tests involving reflection in liue of code generation", func() { - Context("schema.TemplateConfig and model.TemplateType must stay in sync", func() { + Context("config.TemplateConfig and model.TemplateType must stay in sync", func() { - ttc := reflect.TypeOf(schema.TemplateConfig{}) + ttc := reflect.TypeOf(config.TemplateConfig{}) It("TemplateConfig and TemplateType should have the same number of valid values", func() { const lastValidTemplateType = model.IntegrationTestTemplate - 1