From 91465797d3a56db020655cb1b3c5dc454d47dbf0 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 7 Dec 2024 16:25:30 +0100 Subject: [PATCH] Move templating out of model loader Signed-off-by: Ettore Di Giacinto --- core/application.go | 38 ----- core/application/application.go | 39 +++++ .../config_file_watcher.go | 4 +- core/{startup => application}/startup.go | 77 +++------ core/cli/run.go | 8 +- core/http/app.go | 77 +++++---- core/http/app_test.go | 22 +-- core/http/endpoints/openai/chat.go | 6 +- core/http/endpoints/openai/completion.go | 7 +- core/http/endpoints/openai/edit.go | 6 +- core/http/routes/localai.go | 48 +++--- core/http/routes/openai.go | 154 ++++++++++++------ pkg/model/loader.go | 4 - pkg/templates/cache.go | 6 +- .../template.go => templates/evaluator.go} | 43 +++-- .../evaluator_test.go} | 16 +- 16 files changed, 292 insertions(+), 263 deletions(-) delete mode 100644 core/application.go create mode 100644 core/application/application.go rename core/{startup => application}/config_file_watcher.go (96%) rename core/{startup => application}/startup.go (62%) rename pkg/{model/template.go => templates/evaluator.go} (81%) rename pkg/{model/template_test.go => templates/evaluator_test.go} (91%) diff --git a/core/application.go b/core/application.go deleted file mode 100644 index e4efbdd0ab93..000000000000 --- a/core/application.go +++ /dev/null @@ -1,38 +0,0 @@ -package core - -import ( - "github.com/mudler/LocalAI/core/config" - "github.com/mudler/LocalAI/core/services" - "github.com/mudler/LocalAI/pkg/model" -) - -// The purpose of this structure is to hold pointers to all initialized services, to make plumbing easy -// Perhaps a proper DI system is worth it in the future, but for now keep things simple. -type Application struct { - - // Application-Level Config - ApplicationConfig *config.ApplicationConfig - // ApplicationState *ApplicationState - - // Core Low-Level Services - BackendConfigLoader *config.BackendConfigLoader - ModelLoader *model.ModelLoader - - // Backend Services - // EmbeddingsBackendService *backend.EmbeddingsBackendService - // ImageGenerationBackendService *backend.ImageGenerationBackendService - // LLMBackendService *backend.LLMBackendService - // TranscriptionBackendService *backend.TranscriptionBackendService - // TextToSpeechBackendService *backend.TextToSpeechBackendService - - // LocalAI System Services - BackendMonitorService *services.BackendMonitorService - GalleryService *services.GalleryService - LocalAIMetricsService *services.LocalAIMetricsService - // OpenAIService *services.OpenAIService -} - -// TODO [NEXT PR?]: Break up ApplicationConfig. 
-// Migrate over stuff that is not set via config at all - especially runtime stuff -type ApplicationState struct { -} diff --git a/core/application/application.go b/core/application/application.go new file mode 100644 index 000000000000..62b333310179 --- /dev/null +++ b/core/application/application.go @@ -0,0 +1,39 @@ +package application + +import ( + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/templates" +) + +type Application struct { + backendLoader *config.BackendConfigLoader + modelLoader *model.ModelLoader + applicationConfig *config.ApplicationConfig + templatesEvaluator *templates.Evaluator +} + +func newApplication(appConfig *config.ApplicationConfig) *Application { + return &Application{ + backendLoader: config.NewBackendConfigLoader(appConfig.ModelPath), + modelLoader: model.NewModelLoader(appConfig.ModelPath), + applicationConfig: appConfig, + templatesEvaluator: templates.NewEvaluator(templates.NewTemplateCache(appConfig.ModelPath)), + } +} + +func (a *Application) BackendLoader() *config.BackendConfigLoader { + return a.backendLoader +} + +func (a *Application) ModelLoader() *model.ModelLoader { + return a.modelLoader +} + +func (a *Application) ApplicationConfig() *config.ApplicationConfig { + return a.applicationConfig +} + +func (a *Application) TemplatesEvaluator() *templates.Evaluator { + return a.templatesEvaluator +} diff --git a/core/startup/config_file_watcher.go b/core/application/config_file_watcher.go similarity index 96% rename from core/startup/config_file_watcher.go rename to core/application/config_file_watcher.go index df72483f7512..46f29b101acb 100644 --- a/core/startup/config_file_watcher.go +++ b/core/application/config_file_watcher.go @@ -1,4 +1,4 @@ -package startup +package application import ( "encoding/json" @@ -8,8 +8,8 @@ import ( "path/filepath" "time" - "github.com/fsnotify/fsnotify" "dario.cat/mergo" + "github.com/fsnotify/fsnotify" "github.com/mudler/LocalAI/core/config" "github.com/rs/zerolog/log" ) diff --git a/core/startup/startup.go b/core/application/startup.go similarity index 62% rename from core/startup/startup.go rename to core/application/startup.go index 0eb5fa585585..cd52d37ae962 100644 --- a/core/startup/startup.go +++ b/core/application/startup.go @@ -1,15 +1,15 @@ -package startup +package application import ( "fmt" "os" - "github.com/mudler/LocalAI/core" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/internal" "github.com/mudler/LocalAI/pkg/assets" + "github.com/mudler/LocalAI/pkg/library" "github.com/mudler/LocalAI/pkg/model" pkgStartup "github.com/mudler/LocalAI/pkg/startup" @@ -17,8 +17,9 @@ import ( "github.com/rs/zerolog/log" ) -func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.ModelLoader, *config.ApplicationConfig, error) { +func New(opts ...config.AppOption) (*Application, error) { options := config.NewApplicationConfig(opts...) 
+ application := newApplication(options) log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.ModelPath) log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion()) @@ -36,28 +37,28 @@ // Make sure directories exist if options.ModelPath == "" { - return nil, nil, nil, fmt.Errorf("options.ModelPath cannot be empty") + return nil, fmt.Errorf("options.ModelPath cannot be empty") } err = os.MkdirAll(options.ModelPath, 0750) if err != nil { - return nil, nil, nil, fmt.Errorf("unable to create ModelPath: %q", err) + return nil, fmt.Errorf("unable to create ModelPath: %q", err) } if options.ImageDir != "" { err := os.MkdirAll(options.ImageDir, 0750) if err != nil { - return nil, nil, nil, fmt.Errorf("unable to create ImageDir: %q", err) + return nil, fmt.Errorf("unable to create ImageDir: %q", err) } } if options.AudioDir != "" { err := os.MkdirAll(options.AudioDir, 0750) if err != nil { - return nil, nil, nil, fmt.Errorf("unable to create AudioDir: %q", err) + return nil, fmt.Errorf("unable to create AudioDir: %q", err) } } if options.UploadDir != "" { err := os.MkdirAll(options.UploadDir, 0750) if err != nil { - return nil, nil, nil, fmt.Errorf("unable to create UploadDir: %q", err) + return nil, fmt.Errorf("unable to create UploadDir: %q", err) } } @@ -65,39 +66,36 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode log.Error().Err(err).Msg("error installing models") } - cl := config.NewBackendConfigLoader(options.ModelPath) - ml := model.NewModelLoader(options.ModelPath) - configLoaderOpts := options.ToConfigLoaderOptions() - if err := cl.LoadBackendConfigsFromPath(options.ModelPath, configLoaderOpts...); err != nil { + if err := application.BackendLoader().LoadBackendConfigsFromPath(options.ModelPath, configLoaderOpts...); err != nil { log.Error().Err(err).Msg("error loading config files") } if options.ConfigFile != "" { - if err := cl.LoadMultipleBackendConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil { + if err := application.BackendLoader().LoadMultipleBackendConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil { log.Error().Err(err).Msg("error loading config file") } } - if err := cl.Preload(options.ModelPath); err != nil { + if err := application.BackendLoader().Preload(options.ModelPath); err != nil { log.Error().Err(err).Msg("error downloading models") } if options.PreloadJSONModels != "" { if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, options.EnforcePredownloadScans, options.Galleries); err != nil { - return nil, nil, nil, err + return nil, err } } if options.PreloadModelsFromPath != "" { if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, options.EnforcePredownloadScans, options.Galleries); err != nil { - return nil, nil, nil, err + return nil, err } } if options.Debug { - for _, v := range cl.GetAllBackendConfigs() { + for _, v := range application.BackendLoader().GetAllBackendConfigs() { log.Debug().Msgf("Model: %s (config: %+v)", v.Name, v) } } @@ -123,7 +121,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode go func() { <-options.Context.Done() log.Debug().Msgf("Context canceled, shutting down") - err := ml.StopAllGRPC() + err := application.ModelLoader().StopAllGRPC() if err != nil { log.Error().Err(err).Msg("error while stopping all grpc backends") } @@
-131,12 +129,12 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode if options.WatchDog { wd := model.NewWatchDog( - ml, + application.ModelLoader(), options.WatchDogBusyTimeout, options.WatchDogIdleTimeout, options.WatchDogBusy, options.WatchDogIdle) - ml.SetWatchDog(wd) + application.ModelLoader().SetWatchDog(wd) go wd.Run() go func() { <-options.Context.Done() @@ -147,7 +145,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode if options.LoadToMemory != nil { for _, m := range options.LoadToMemory { - cfg, err := cl.LoadBackendConfigFileByName(m, options.ModelPath, + cfg, err := application.BackendLoader().LoadBackendConfigFileByName(m, options.ModelPath, config.LoadOptionDebug(options.Debug), config.LoadOptionThreads(options.Threads), config.LoadOptionContextSize(options.ContextSize), @@ -155,7 +153,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode config.ModelPath(options.ModelPath), ) if err != nil { - return nil, nil, nil, err + return nil, err } log.Debug().Msgf("Auto loading model %s into memory from file: %s", m, cfg.Model) @@ -163,9 +161,9 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode o := backend.ModelOptions(*cfg, options) var backendErr error - _, backendErr = ml.Load(o...) + _, backendErr = application.ModelLoader().Load(o...) if backendErr != nil { - return nil, nil, nil, err + return nil, backendErr } } } @@ -174,7 +172,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode startWatcher(options) log.Info().Msg("core/startup process completed!") - return cl, ml, options, nil + return application, nil } func startWatcher(options *config.ApplicationConfig) { @@ -201,32 +199,3 @@ func startWatcher(options *config.ApplicationConfig) { log.Error().Err(err).Msg("failed creating watcher") } } - -// In Lieu of a proper DI framework, this function wires up the Application manually. -// This is in core/startup rather than core/state.go to keep package references clean!
-func createApplication(appConfig *config.ApplicationConfig) *core.Application { - app := &core.Application{ - ApplicationConfig: appConfig, - BackendConfigLoader: config.NewBackendConfigLoader(appConfig.ModelPath), - ModelLoader: model.NewModelLoader(appConfig.ModelPath), - } - - var err error - - // app.EmbeddingsBackendService = backend.NewEmbeddingsBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - // app.ImageGenerationBackendService = backend.NewImageGenerationBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - // app.LLMBackendService = backend.NewLLMBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - // app.TranscriptionBackendService = backend.NewTranscriptionBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - // app.TextToSpeechBackendService = backend.NewTextToSpeechBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - - app.BackendMonitorService = services.NewBackendMonitorService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - app.GalleryService = services.NewGalleryService(app.ApplicationConfig) - // app.OpenAIService = services.NewOpenAIService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig, app.LLMBackendService) - - app.LocalAIMetricsService, err = services.NewLocalAIMetricsService() - if err != nil { - log.Error().Err(err).Msg("encountered an error initializing metrics service, startup will continue but metrics will not be tracked.") - } - - return app -} diff --git a/core/cli/run.go b/core/cli/run.go index b2d439a05b9f..a0e16155e2c6 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -6,12 +6,12 @@ import ( "strings" "time" + "github.com/mudler/LocalAI/core/application" cli_api "github.com/mudler/LocalAI/core/cli/api" cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http" "github.com/mudler/LocalAI/core/p2p" - "github.com/mudler/LocalAI/core/startup" "github.com/rs/zerolog" "github.com/rs/zerolog/log" ) @@ -186,16 +186,16 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { } if r.PreloadBackendOnly { - _, _, _, err := startup.Startup(opts...) + _, err := application.New(opts...) return err } - cl, ml, options, err := startup.Startup(opts...) + app, err := application.New(opts...) 
if err != nil { return fmt.Errorf("failed basic startup tasks with error %s", err.Error()) } - appHTTP, err := http.App(cl, ml, options) + appHTTP, err := http.API(app) if err != nil { log.Error().Err(err).Msg("error during HTTP App construction") return err diff --git a/core/http/app.go b/core/http/app.go index 2ba2c2b99535..a2d8b87a2f73 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -14,10 +14,9 @@ import ( "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/routes" - "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services" - "github.com/mudler/LocalAI/pkg/model" "github.com/gofiber/contrib/fiberzerolog" "github.com/gofiber/fiber/v2" @@ -49,18 +48,18 @@ var embedDirStatic embed.FS // @in header // @name Authorization -func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (*fiber.App, error) { +func API(application *application.Application) (*fiber.App, error) { fiberCfg := fiber.Config{ Views: renderEngine(), - BodyLimit: appConfig.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB + BodyLimit: application.ApplicationConfig().UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB // We disable the Fiber startup message as it does not conform to structured logging. // We register a startup log line with connection information in the OnListen hook to keep things user-friendly though DisableStartupMessage: true, // Override default error handler } - if !appConfig.OpaqueErrors { + if !application.ApplicationConfig().OpaqueErrors { // Normally, return errors as JSON responses fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, err error) error { // Status code defaults to 500 @@ -86,9 +85,9 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi } } - app := fiber.New(fiberCfg) + router := fiber.New(fiberCfg) - app.Hooks().OnListen(func(listenData fiber.ListenData) error { + router.Hooks().OnListen(func(listenData fiber.ListenData) error { scheme := "http" if listenData.TLS { scheme = "https" } @@ -99,82 +98,82 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi // Have Fiber use zerolog like the rest of the application rather than its built-in logger logger := log.Logger - app.Use(fiberzerolog.New(fiberzerolog.Config{ + router.Use(fiberzerolog.New(fiberzerolog.Config{ Logger: &logger, })) // Default middleware config - if !appConfig.Debug { - app.Use(recover.New()) + if !application.ApplicationConfig().Debug { + router.Use(recover.New()) } - if !appConfig.DisableMetrics { + if !application.ApplicationConfig().DisableMetrics { metricsService, err := services.NewLocalAIMetricsService() if err != nil { return nil, err } if metricsService != nil { - app.Use(localai.LocalAIMetricsAPIMiddleware(metricsService)) - app.Hooks().OnShutdown(func() error { + router.Use(localai.LocalAIMetricsAPIMiddleware(metricsService)) + router.Hooks().OnShutdown(func() error { return metricsService.Shutdown() }) } } // Health Checks should always be exempt from auth, so register these first - routes.HealthRoutes(app) + routes.HealthRoutes(router) - kaConfig, err := middleware.GetKeyAuthConfig(appConfig) + kaConfig, err := middleware.GetKeyAuthConfig(application.ApplicationConfig()) if err != nil || kaConfig == nil { return nil, fmt.Errorf("failed to create key auth config: %w", err) } // Auth is applied to _all_ endpoints. No exceptions.
Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration - app.Use(v2keyauth.New(*kaConfig)) + router.Use(v2keyauth.New(*kaConfig)) - if appConfig.CORS { + if application.ApplicationConfig().CORS { var c func(ctx *fiber.Ctx) error - if appConfig.CORSAllowOrigins == "" { + if application.ApplicationConfig().CORSAllowOrigins == "" { c = cors.New() } else { - c = cors.New(cors.Config{AllowOrigins: appConfig.CORSAllowOrigins}) + c = cors.New(cors.Config{AllowOrigins: application.ApplicationConfig().CORSAllowOrigins}) } - app.Use(c) + router.Use(c) } - if appConfig.CSRF { + if application.ApplicationConfig().CSRF { log.Debug().Msg("Enabling CSRF middleware. Tokens are now required for state-modifying requests") - app.Use(csrf.New()) + router.Use(csrf.New()) } // Load config jsons - utils.LoadConfig(appConfig.UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles) - utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants) - utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles) - - galleryService := services.NewGalleryService(appConfig) - galleryService.Start(appConfig.Context, cl) - - routes.RegisterElevenLabsRoutes(app, cl, ml, appConfig) - routes.RegisterLocalAIRoutes(app, cl, ml, appConfig, galleryService) - routes.RegisterOpenAIRoutes(app, cl, ml, appConfig) - if !appConfig.DisableWebUI { - routes.RegisterUIRoutes(app, cl, ml, appConfig, galleryService) + utils.LoadConfig(application.ApplicationConfig().UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles) + utils.LoadConfig(application.ApplicationConfig().ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants) + utils.LoadConfig(application.ApplicationConfig().ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles) + + galleryService := services.NewGalleryService(application.ApplicationConfig()) + galleryService.Start(application.ApplicationConfig().Context, application.BackendLoader()) + + routes.RegisterElevenLabsRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()) + routes.RegisterLocalAIRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService) + routes.RegisterOpenAIRoutes(router, application) + if !application.ApplicationConfig().DisableWebUI { + routes.RegisterUIRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService) } - routes.RegisterJINARoutes(app, cl, ml, appConfig) + routes.RegisterJINARoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()) httpFS := http.FS(embedDirStatic) - app.Use(favicon.New(favicon.Config{ + router.Use(favicon.New(favicon.Config{ URL: "/favicon.ico", FileSystem: httpFS, File: "static/favicon.ico", })) - app.Use("/static", filesystem.New(filesystem.Config{ + router.Use("/static", filesystem.New(filesystem.Config{ Root: httpFS, PathPrefix: "static", Browse: true, @@ -182,7 +181,7 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi // Define a custom 404 handler // Note: keep this at the bottom! 
- app.Use(notFoundHandler) + router.Use(notFoundHandler) - return app, nil + return router, nil } diff --git a/core/http/app_test.go b/core/http/app_test.go index 83fb0e73bc74..7669872b3748 100644 --- a/core/http/app_test.go +++ b/core/http/app_test.go @@ -12,15 +12,14 @@ import ( "path/filepath" "runtime" + "github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/config" . "github.com/mudler/LocalAI/core/http" "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/LocalAI/core/startup" "github.com/gofiber/fiber/v2" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/pkg/downloader" - "github.com/mudler/LocalAI/pkg/model" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "gopkg.in/yaml.v3" @@ -252,9 +251,6 @@ var _ = Describe("API test", func() { var cancel context.CancelFunc var tmpdir string var modelDir string - var bcl *config.BackendConfigLoader - var ml *model.ModelLoader - var applicationConfig *config.ApplicationConfig commonOpts := []config.AppOption{ config.WithDebug(true), @@ -300,7 +296,7 @@ var _ = Describe("API test", func() { }, } - bcl, ml, applicationConfig, err = startup.Startup( + application, err := application.New( append(commonOpts, config.WithContext(c), config.WithGalleries(galleries), @@ -310,7 +306,7 @@ var _ = Describe("API test", func() { config.WithBackendAssetsOutput(backendAssetsDir))...) Expect(err).ToNot(HaveOccurred()) - app, err = App(bcl, ml, applicationConfig) + app, err = API(application) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -641,7 +637,7 @@ var _ = Describe("API test", func() { }, } - bcl, ml, applicationConfig, err = startup.Startup( + application, err := application.New( append(commonOpts, config.WithContext(c), config.WithAudioDir(tmpdir), @@ -652,7 +648,7 @@ var _ = Describe("API test", func() { config.WithBackendAssetsOutput(tmpdir))..., ) Expect(err).ToNot(HaveOccurred()) - app, err = App(bcl, ml, applicationConfig) + app, err = API(application) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -772,14 +768,14 @@ var _ = Describe("API test", func() { var err error - bcl, ml, applicationConfig, err = startup.Startup( + application, err := application.New( append(commonOpts, config.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")), config.WithContext(c), config.WithModelPath(modelPath), )...) 
Expect(err).ToNot(HaveOccurred()) - app, err = App(bcl, ml, applicationConfig) + app, err = API(application) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -990,14 +986,14 @@ var _ = Describe("API test", func() { c, cancel = context.WithCancel(context.Background()) var err error - bcl, ml, applicationConfig, err = startup.Startup( + application, err := application.New( append(commonOpts, config.WithContext(c), config.WithModelPath(modelPath), config.WithConfigFile(os.Getenv("CONFIG_FILE")))..., ) Expect(err).ToNot(HaveOccurred()) - app, err = App(bcl, ml, applicationConfig) + app, err = API(application) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index 85834b9be70d..b334a8d75ec1 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -14,6 +14,8 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/functions" + "github.com/mudler/LocalAI/pkg/templates" + model "github.com/mudler/LocalAI/pkg/model" "github.com/rs/zerolog/log" "github.com/valyala/fasthttp" @@ -24,7 +26,7 @@ import ( // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/chat/completions [post] -func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error { +func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error { var id, textContentToReturn string var created int @@ -298,7 +300,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup // If we are using the tokenizer template, we don't need to process the messages // unless we are processing functions if !config.TemplateConfig.UseTokenizerTemplate || shouldUseFn { - predInput = ml.TemplateMessages(input.Messages, config, funcs, shouldUseFn) + predInput = evaluator.TemplateMessages(input.Messages, config, funcs, shouldUseFn) log.Debug().Msgf("Prompt (after templating): %s", predInput) if shouldUseFn && config.Grammar != "" { diff --git a/core/http/endpoints/openai/completion.go b/core/http/endpoints/openai/completion.go index 4567bd4cffaf..04ebc847905f 100644 --- a/core/http/endpoints/openai/completion.go +++ b/core/http/endpoints/openai/completion.go @@ -16,6 +16,7 @@ import ( "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/functions" model "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/templates" "github.com/rs/zerolog/log" "github.com/valyala/fasthttp" ) @@ -25,7 +26,7 @@ import ( // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/completions [post] -func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { id := uuid.New().String() created := int(time.Now().Unix()) @@ -101,7 +102,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a predInput := config.PromptStrings[0] - templatedInput, err := 
ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, *config, model.PromptTemplateData{ + templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{ Input: predInput, SystemPrompt: config.SystemPrompt, }) @@ -152,7 +153,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a totalTokenUsage := backend.TokenUsage{} for k, i := range config.PromptStrings { - templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, *config, model.PromptTemplateData{ + templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{ SystemPrompt: config.SystemPrompt, Input: i, }) diff --git a/core/http/endpoints/openai/edit.go b/core/http/endpoints/openai/edit.go index 0a459c434cce..a6d609fbfeb3 100644 --- a/core/http/endpoints/openai/edit.go +++ b/core/http/endpoints/openai/edit.go @@ -12,6 +12,7 @@ import ( "github.com/google/uuid" "github.com/mudler/LocalAI/core/schema" model "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/templates" "github.com/rs/zerolog/log" ) @@ -21,7 +22,8 @@ import ( // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/edits [post] -func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { modelFile, input, err := readRequest(c, cl, ml, appConfig, true) if err != nil { @@ -39,7 +41,7 @@ func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConf totalTokenUsage := backend.TokenUsage{} for _, i := range config.InputStrings { - templatedInput, err := ml.EvaluateTemplateForPrompt(model.EditPromptTemplate, *config, model.PromptTemplateData{ + templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.EditPromptTemplate, *config, templates.PromptTemplateData{ Input: i, Instruction: input.Instruction, SystemPrompt: config.SystemPrompt, diff --git a/core/http/routes/localai.go b/core/http/routes/localai.go index e7097741aa16..2ea9896a2ec8 100644 --- a/core/http/routes/localai.go +++ b/core/http/routes/localai.go @@ -11,62 +11,62 @@ import ( "github.com/mudler/LocalAI/pkg/model" ) -func RegisterLocalAIRoutes(app *fiber.App, +func RegisterLocalAIRoutes(router *fiber.App, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService) { - app.Get("/swagger/*", swagger.HandlerDefault) // default + router.Get("/swagger/*", swagger.HandlerDefault) // default // LocalAI API endpoints if !appConfig.DisableGalleryEndpoint { modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService) - app.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint()) - app.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint()) + router.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint()) + router.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint()) - app.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint()) - 
app.Get("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint()) - app.Post("/models/galleries", modelGalleryEndpointService.AddModelGalleryEndpoint()) - app.Delete("/models/galleries", modelGalleryEndpointService.RemoveModelGalleryEndpoint()) - app.Get("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint()) - app.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint()) + router.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint()) + router.Get("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint()) + router.Post("/models/galleries", modelGalleryEndpointService.AddModelGalleryEndpoint()) + router.Delete("/models/galleries", modelGalleryEndpointService.RemoveModelGalleryEndpoint()) + router.Get("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint()) + router.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint()) } - app.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig)) - app.Post("/vad", localai.VADEndpoint(cl, ml, appConfig)) + router.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig)) + router.Post("/vad", localai.VADEndpoint(cl, ml, appConfig)) // Stores sl := model.NewModelLoader("") - app.Post("/stores/set", localai.StoresSetEndpoint(sl, appConfig)) - app.Post("/stores/delete", localai.StoresDeleteEndpoint(sl, appConfig)) - app.Post("/stores/get", localai.StoresGetEndpoint(sl, appConfig)) - app.Post("/stores/find", localai.StoresFindEndpoint(sl, appConfig)) + router.Post("/stores/set", localai.StoresSetEndpoint(sl, appConfig)) + router.Post("/stores/delete", localai.StoresDeleteEndpoint(sl, appConfig)) + router.Post("/stores/get", localai.StoresGetEndpoint(sl, appConfig)) + router.Post("/stores/find", localai.StoresFindEndpoint(sl, appConfig)) if !appConfig.DisableMetrics { - app.Get("/metrics", localai.LocalAIMetricsEndpoint()) + router.Get("/metrics", localai.LocalAIMetricsEndpoint()) } // Experimental Backend Statistics Module backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now - app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService)) - app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService)) + router.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService)) + router.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService)) // p2p if p2p.IsP2PEnabled() { - app.Get("/api/p2p", localai.ShowP2PNodes(appConfig)) - app.Get("/api/p2p/token", localai.ShowP2PToken(appConfig)) + router.Get("/api/p2p", localai.ShowP2PNodes(appConfig)) + router.Get("/api/p2p/token", localai.ShowP2PToken(appConfig)) } - app.Get("/version", func(c *fiber.Ctx) error { + router.Get("/version", func(c *fiber.Ctx) error { return c.JSON(struct { Version string `json:"version"` }{Version: internal.PrintableVersion()}) }) - app.Get("/system", localai.SystemInformations(ml, appConfig)) + router.Get("/system", localai.SystemInformations(ml, appConfig)) // misc - app.Post("/v1/tokenize", localai.TokenizeEndpoint(cl, ml, appConfig)) + router.Post("/v1/tokenize", localai.TokenizeEndpoint(cl, ml, appConfig)) } diff --git a/core/http/routes/openai.go b/core/http/routes/openai.go index 081daf70d80c..5ff301b673bc 100644 --- a/core/http/routes/openai.go +++ b/core/http/routes/openai.go @@ -2,84 +2,134 @@ package routes import ( "github.com/gofiber/fiber/v2" - "github.com/mudler/LocalAI/core/config" + 
"github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/http/endpoints/localai" "github.com/mudler/LocalAI/core/http/endpoints/openai" - "github.com/mudler/LocalAI/pkg/model" ) func RegisterOpenAIRoutes(app *fiber.App, - cl *config.BackendConfigLoader, - ml *model.ModelLoader, - appConfig *config.ApplicationConfig) { + application *application.Application) { // openAI compatible API endpoint // chat - app.Post("/v1/chat/completions", openai.ChatEndpoint(cl, ml, appConfig)) - app.Post("/chat/completions", openai.ChatEndpoint(cl, ml, appConfig)) + app.Post("/v1/chat/completions", + openai.ChatEndpoint( + application.BackendLoader(), + application.ModelLoader(), + application.TemplatesEvaluator(), + application.ApplicationConfig(), + ), + ) + + app.Post("/chat/completions", + openai.ChatEndpoint( + application.BackendLoader(), + application.ModelLoader(), + application.TemplatesEvaluator(), + application.ApplicationConfig(), + ), + ) // edit - app.Post("/v1/edits", openai.EditEndpoint(cl, ml, appConfig)) - app.Post("/edits", openai.EditEndpoint(cl, ml, appConfig)) + app.Post("/v1/edits", + openai.EditEndpoint( + application.BackendLoader(), + application.ModelLoader(), + application.TemplatesEvaluator(), + application.ApplicationConfig(), + ), + ) + + app.Post("/edits", + openai.EditEndpoint( + application.BackendLoader(), + application.ModelLoader(), + application.TemplatesEvaluator(), + application.ApplicationConfig(), + ), + ) // assistant - app.Get("/v1/assistants", openai.ListAssistantsEndpoint(cl, ml, appConfig)) - app.Get("/assistants", openai.ListAssistantsEndpoint(cl, ml, appConfig)) - app.Post("/v1/assistants", openai.CreateAssistantEndpoint(cl, ml, appConfig)) - app.Post("/assistants", openai.CreateAssistantEndpoint(cl, ml, appConfig)) - app.Delete("/v1/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, ml, appConfig)) - app.Delete("/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, ml, appConfig)) - app.Get("/v1/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, ml, appConfig)) - app.Get("/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, ml, appConfig)) - app.Post("/v1/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, ml, appConfig)) - app.Post("/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, ml, appConfig)) - app.Get("/v1/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, ml, appConfig)) - app.Get("/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, ml, appConfig)) - app.Post("/v1/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, ml, appConfig)) - app.Post("/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, ml, appConfig)) - app.Delete("/v1/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, ml, appConfig)) - app.Delete("/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, ml, appConfig)) - app.Get("/v1/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, ml, appConfig)) - app.Get("/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, ml, appConfig)) + app.Get("/v1/assistants", openai.ListAssistantsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Get("/assistants", openai.ListAssistantsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Post("/v1/assistants", 
openai.CreateAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Post("/assistants", openai.CreateAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Delete("/v1/assistants/:assistant_id", openai.DeleteAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Delete("/assistants/:assistant_id", openai.DeleteAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Get("/v1/assistants/:assistant_id", openai.GetAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Get("/assistants/:assistant_id", openai.GetAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Post("/v1/assistants/:assistant_id", openai.ModifyAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Post("/assistants/:assistant_id", openai.ModifyAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Get("/v1/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Get("/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Post("/v1/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Post("/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Delete("/v1/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Delete("/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Get("/v1/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Get("/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) // files - app.Post("/v1/files", openai.UploadFilesEndpoint(cl, appConfig)) - app.Post("/files", openai.UploadFilesEndpoint(cl, appConfig)) - app.Get("/v1/files", openai.ListFilesEndpoint(cl, appConfig)) - app.Get("/files", openai.ListFilesEndpoint(cl, appConfig)) - app.Get("/v1/files/:file_id", openai.GetFilesEndpoint(cl, appConfig)) - app.Get("/files/:file_id", openai.GetFilesEndpoint(cl, appConfig)) - app.Delete("/v1/files/:file_id", openai.DeleteFilesEndpoint(cl, appConfig)) - app.Delete("/files/:file_id", openai.DeleteFilesEndpoint(cl, appConfig)) - app.Get("/v1/files/:file_id/content", openai.GetFilesContentsEndpoint(cl, appConfig)) - app.Get("/files/:file_id/content", openai.GetFilesContentsEndpoint(cl, appConfig)) + app.Post("/v1/files", openai.UploadFilesEndpoint(application.BackendLoader(), application.ApplicationConfig())) + app.Post("/files", 
openai.UploadFilesEndpoint(application.BackendLoader(), application.ApplicationConfig())) + app.Get("/v1/files", openai.ListFilesEndpoint(application.BackendLoader(), application.ApplicationConfig())) + app.Get("/files", openai.ListFilesEndpoint(application.BackendLoader(), application.ApplicationConfig())) + app.Get("/v1/files/:file_id", openai.GetFilesEndpoint(application.BackendLoader(), application.ApplicationConfig())) + app.Get("/files/:file_id", openai.GetFilesEndpoint(application.BackendLoader(), application.ApplicationConfig())) + app.Delete("/v1/files/:file_id", openai.DeleteFilesEndpoint(application.BackendLoader(), application.ApplicationConfig())) + app.Delete("/files/:file_id", openai.DeleteFilesEndpoint(application.BackendLoader(), application.ApplicationConfig())) + app.Get("/v1/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig())) + app.Get("/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig())) // completion - app.Post("/v1/completions", openai.CompletionEndpoint(cl, ml, appConfig)) - app.Post("/completions", openai.CompletionEndpoint(cl, ml, appConfig)) - app.Post("/v1/engines/:model/completions", openai.CompletionEndpoint(cl, ml, appConfig)) + app.Post("/v1/completions", + openai.CompletionEndpoint( + application.BackendLoader(), + application.ModelLoader(), + application.TemplatesEvaluator(), + application.ApplicationConfig(), + ), + ) + + app.Post("/completions", + openai.CompletionEndpoint( + application.BackendLoader(), + application.ModelLoader(), + application.TemplatesEvaluator(), + application.ApplicationConfig(), + ), + ) + + app.Post("/v1/engines/:model/completions", + openai.CompletionEndpoint( + application.BackendLoader(), + application.ModelLoader(), + application.TemplatesEvaluator(), + application.ApplicationConfig(), + ), + ) // embeddings - app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig)) - app.Post("/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig)) - app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig)) + app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Post("/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) // audio - app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(cl, ml, appConfig)) - app.Post("/v1/audio/speech", localai.TTSEndpoint(cl, ml, appConfig)) + app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Post("/v1/audio/speech", localai.TTSEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) // images - app.Post("/v1/images/generations", openai.ImageEndpoint(cl, ml, appConfig)) + app.Post("/v1/images/generations", openai.ImageEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) - if appConfig.ImageDir != "" { - app.Static("/generated-images", appConfig.ImageDir) + if application.ApplicationConfig().ImageDir != "" { + app.Static("/generated-images", application.ApplicationConfig().ImageDir) } - if 
appConfig.AudioDir != "" { - app.Static("/generated-audio", appConfig.AudioDir) + if application.ApplicationConfig().AudioDir != "" { + app.Static("/generated-audio", application.ApplicationConfig().AudioDir) } // List models - app.Get("/v1/models", openai.ListModelsEndpoint(cl, ml)) - app.Get("/models", openai.ListModelsEndpoint(cl, ml)) + app.Get("/v1/models", openai.ListModelsEndpoint(application.BackendLoader(), application.ModelLoader())) + app.Get("/models", openai.ListModelsEndpoint(application.BackendLoader(), application.ModelLoader())) } diff --git a/pkg/model/loader.go b/pkg/model/loader.go index b32e3745efc9..d62f52b23855 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -9,8 +9,6 @@ import ( "sync" "time" - "github.com/mudler/LocalAI/pkg/templates" - "github.com/mudler/LocalAI/pkg/utils" "github.com/rs/zerolog/log" @@ -23,7 +21,6 @@ type ModelLoader struct { ModelPath string mu sync.Mutex models map[string]*Model - templates *templates.TemplateCache wd *WatchDog } @@ -31,7 +28,6 @@ func NewModelLoader(modelPath string) *ModelLoader { nml := &ModelLoader{ ModelPath: modelPath, models: make(map[string]*Model), - templates: templates.NewTemplateCache(modelPath), } return nml diff --git a/pkg/templates/cache.go b/pkg/templates/cache.go index 30ef07b2923b..82c48b2b1b30 100644 --- a/pkg/templates/cache.go +++ b/pkg/templates/cache.go @@ -42,6 +42,10 @@ func (tc *TemplateCache) initializeTemplateMapKey(tt TemplateType) { } } +func (tc *TemplateCache) ExistsInModelPath(s string) bool { + return utils.ExistsInPath(tc.templatesPath, s) +} + func (tc *TemplateCache) EvaluateTemplate(templateType TemplateType, templateNameOrContent string, in interface{}) (string, error) { tc.mu.Lock() defer tc.mu.Unlock() @@ -88,7 +92,7 @@ func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templat } // can either be a file in the system or a string with the template - if utils.ExistsInPath(tc.templatesPath, modelTemplateFile) { + if tc.ExistsInModelPath(modelTemplateFile) { d, err := os.ReadFile(file) if err != nil { return err diff --git a/pkg/model/template.go b/pkg/templates/evaluator.go similarity index 81% rename from pkg/model/template.go rename to pkg/templates/evaluator.go index 28864f5593af..7b2089b3ac0b 100644 --- a/pkg/model/template.go +++ b/pkg/templates/evaluator.go @@ -1,4 +1,4 @@ -package model +package templates import ( "encoding/json" @@ -8,7 +8,6 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/functions" - "github.com/mudler/LocalAI/pkg/templates" "github.com/rs/zerolog/log" ) @@ -37,18 +36,28 @@ type ChatMessageTemplateData struct { } const ( - ChatPromptTemplate templates.TemplateType = iota + ChatPromptTemplate TemplateType = iota ChatMessageTemplate CompletionPromptTemplate EditPromptTemplate FunctionsPromptTemplate ) -func (ml *ModelLoader) EvaluateTemplateForPrompt(templateType templates.TemplateType, config config.BackendConfig, in PromptTemplateData) (string, error) { +type Evaluator struct { + cache *TemplateCache +} + +func NewEvaluator(cache *TemplateCache) *Evaluator { + return &Evaluator{ + cache: cache, + } +} + +func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config config.BackendConfig, in PromptTemplateData) (string, error) { template := "" // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { + if 
e.cache.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { template = config.Model } @@ -76,17 +85,17 @@ func (ml *ModelLoader) EvaluateTemplateForPrompt(templateType templates.Template } if config.TemplateConfig.JinjaTemplate { - return ml.EvaluateJinjaTemplateForPrompt(templateType, template, in) + return e.EvaluateJinjaTemplateForPrompt(templateType, template, in) } - return ml.templates.EvaluateTemplate(templateType, template, in) + return e.cache.EvaluateTemplate(templateType, template, in) } -func (ml *ModelLoader) EvaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) { - return ml.templates.EvaluateTemplate(ChatMessageTemplate, templateName, messageData) +func (e *Evaluator) EvaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) { + return e.cache.EvaluateTemplate(ChatMessageTemplate, templateName, messageData) } -func (ml *ModelLoader) templateJinjaChat(templateName string, messageData []ChatMessageTemplateData) (string, error) { +func (e *Evaluator) templateJinjaChat(templateName string, messageData []ChatMessageTemplateData) (string, error) { conversation := make(map[string]interface{}) messages := make([]map[string]interface{}, len(messageData)) @@ -108,20 +117,20 @@ func (ml *ModelLoader) templateJinjaChat(templateName string, messageData []Chat conversation["messages"] = messages - return ml.templates.EvaluateJinjaTemplate(ChatMessageTemplate, templateName, conversation) + return e.cache.EvaluateJinjaTemplate(ChatMessageTemplate, templateName, conversation) } -func (ml *ModelLoader) EvaluateJinjaTemplateForPrompt(templateType templates.TemplateType, templateName string, in PromptTemplateData) (string, error) { +func (e *Evaluator) EvaluateJinjaTemplateForPrompt(templateType TemplateType, templateName string, in PromptTemplateData) (string, error) { conversation := make(map[string]interface{}) conversation["system_prompt"] = in.SystemPrompt conversation["content"] = in.Input - return ml.templates.EvaluateJinjaTemplate(templateType, templateName, conversation) + return e.cache.EvaluateJinjaTemplate(templateType, templateName, conversation) } -func (ml *ModelLoader) TemplateMessages(messages []schema.Message, config *config.BackendConfig, funcs []functions.Function, shouldUseFn bool) string { +func (e *Evaluator) TemplateMessages(messages []schema.Message, config *config.BackendConfig, funcs []functions.Function, shouldUseFn bool) string { if config.TemplateConfig.JinjaTemplate { var messageData []ChatMessageTemplateData @@ -143,7 +152,7 @@ func (ml *ModelLoader) TemplateMessages(messages []schema.Message, config *confi }) } - templatedInput, err := ml.templateJinjaChat(config.TemplateConfig.ChatMessage, messageData) + templatedInput, err := e.templateJinjaChat(config.TemplateConfig.ChatMessage, messageData) if err == nil { return templatedInput } @@ -186,7 +195,7 @@ func (ml *ModelLoader) TemplateMessages(messages []schema.Message, config *confi Function: config.Grammar != "" && (messageIndex == (len(messages) - 1)), MessageIndex: messageIndex, } - templatedChatMessage, err := ml.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData) + templatedChatMessage, err := e.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData) if err != nil { log.Error().Err(err).Interface("message", chatMessageData).Str("template", config.TemplateConfig.ChatMessage).Msg("error processing message with template, skipping") } else { @@ -266,7 
+275,7 @@ func (ml *ModelLoader) TemplateMessages(messages []schema.Message, config *confi promptTemplate = FunctionsPromptTemplate } - templatedInput, err := ml.EvaluateTemplateForPrompt(promptTemplate, *config, PromptTemplateData{ + templatedInput, err := e.EvaluateTemplateForPrompt(promptTemplate, *config, PromptTemplateData{ SystemPrompt: config.SystemPrompt, SuppressSystemPrompt: suppressConfigSystemPrompt, Input: predInput, diff --git a/pkg/model/template_test.go b/pkg/templates/evaluator_test.go similarity index 91% rename from pkg/model/template_test.go rename to pkg/templates/evaluator_test.go index 1142ed0c529b..06551e4d5abe 100644 --- a/pkg/model/template_test.go +++ b/pkg/templates/evaluator_test.go @@ -1,7 +1,7 @@ -package model_test +package templates_test import ( - . "github.com/mudler/LocalAI/pkg/model" + . "github.com/mudler/LocalAI/pkg/templates" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -167,28 +167,28 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in var _ = Describe("Templates", func() { Context("chat message ChatML", func() { - var modelLoader *ModelLoader + var evaluator *Evaluator BeforeEach(func() { - modelLoader = NewModelLoader("") + evaluator = NewEvaluator(NewTemplateCache("")) }) for key := range chatMLTestMatch { foo := chatMLTestMatch[key] It("renders correctly `"+key+"`", func() { - templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData)) + templated, err := evaluator.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData)) Expect(err).ToNot(HaveOccurred()) Expect(templated).To(Equal(foo["expected"]), templated) }) } }) Context("chat message llama3", func() { - var modelLoader *ModelLoader + var evaluator *Evaluator BeforeEach(func() { - modelLoader = NewModelLoader("") + evaluator = NewEvaluator(NewTemplateCache("")) }) for key := range llama3TestMatch { foo := llama3TestMatch[key] It("renders correctly `"+key+"`", func() { - templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData)) + templated, err := evaluator.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData)) Expect(err).ToNot(HaveOccurred()) Expect(templated).To(Equal(foo["expected"]), templated) })
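Usage sketch (editor's addendum, not part of the diff): after this patch the old four-value flow cl, ml, options, err := startup.Startup(opts...) collapses into a single *Application that http.API consumes whole. A minimal boot sketch assuming only the constructors, accessors, and config options visible above; the model path and listen address are illustrative:

	package main

	import (
		"context"

		"github.com/mudler/LocalAI/core/application"
		"github.com/mudler/LocalAI/core/config"
		"github.com/mudler/LocalAI/core/http"

		"github.com/rs/zerolog/log"
	)

	func main() {
		// application.New replaces startup.Startup: it wires the backend config
		// loader, model loader, and templates evaluator into one value.
		app, err := application.New(
			config.WithContext(context.Background()),
			config.WithModelPath("./models"), // illustrative path
		)
		if err != nil {
			log.Fatal().Err(err).Msg("startup failed")
		}

		// http.API (formerly http.App) now takes the whole Application instead
		// of the loader/loader/config triple.
		router, err := http.API(app)
		if err != nil {
			log.Fatal().Err(err).Msg("HTTP app construction failed")
		}

		log.Fatal().Err(router.Listen("127.0.0.1:8080")).Msg("server exited") // illustrative address
	}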
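The extracted templates.Evaluator is also usable on its own, constructed exactly as in application.go and the tests above via templates.NewEvaluator(templates.NewTemplateCache(modelPath)). A compile-level sketch of the call the completion endpoint now makes: the inputs are illustrative, the zero-value BackendConfig stands in for a loaded one, and which template actually fires depends on TemplateConfig fields elided from this hunk:

	package main

	import (
		"fmt"

		"github.com/mudler/LocalAI/core/config"
		"github.com/mudler/LocalAI/pkg/templates"
	)

	func main() {
		// The evaluator owns the template cache that previously lived inside
		// pkg/model's ModelLoader.
		evaluator := templates.NewEvaluator(templates.NewTemplateCache("./models")) // illustrative path

		var cfg config.BackendConfig // zero value for illustration; real callers load one

		// Mirrors core/http/endpoints/openai/completion.go after this patch.
		prompt, err := evaluator.EvaluateTemplateForPrompt(
			templates.CompletionPromptTemplate,
			cfg,
			templates.PromptTemplateData{
				Input:        "What is LocalAI?",             // illustrative
				SystemPrompt: "You are a helpful assistant.", // illustrative
			},
		)
		if err != nil {
			fmt.Println("template error:", err)
			return
		}
		fmt.Println(prompt)
	}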