Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(refactor): drop unnecessary code in loader #4096

Merged
merged 2 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions core/backend/embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,9 @@ import (

func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {

var inferenceModel interface{}
var err error
opts := ModelOptions(backendConfig, appConfig)

opts := ModelOptions(backendConfig, appConfig, []model.Option{})

if backendConfig.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)
} else {
opts = append(opts, model.WithBackendString(backendConfig.Backend))
inferenceModel, err = loader.BackendLoader(opts...)
}
inferenceModel, err := loader.Load(opts...)
if err != nil {
return nil, err
}
Expand Down
5 changes: 2 additions & 3 deletions core/backend/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@ import (

func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {

opts := ModelOptions(backendConfig, appConfig, []model.Option{})

inferenceModel, err := loader.BackendLoader(
opts := ModelOptions(backendConfig, appConfig)
inferenceModel, err := loader.Load(
opts...,
)
if err != nil {
Expand Down
18 changes: 2 additions & 16 deletions core/backend/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"github.com/mudler/LocalAI/core/schema"

"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/grpc"
"github.com/mudler/LocalAI/pkg/grpc/proto"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/utils"
Expand All @@ -35,15 +34,6 @@ type TokenUsage struct {
func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
modelFile := c.Model

var inferenceModel grpc.Backend
var err error

opts := ModelOptions(c, o, []model.Option{})

if c.Backend != "" {
opts = append(opts, model.WithBackendString(c.Backend))
}

// Check if the modelFile exists, if it doesn't try to load it from the gallery
if o.AutoloadGalleries { // experimental
if _, err := os.Stat(modelFile); os.IsNotExist(err) {
Expand All @@ -56,12 +46,8 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
}
}

if c.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)
} else {
inferenceModel, err = loader.BackendLoader(opts...)
}

opts := ModelOptions(c, o)
inferenceModel, err := loader.Load(opts...)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion core/backend/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"github.com/rs/zerolog/log"
)

func ModelOptions(c config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option {
func ModelOptions(c config.BackendConfig, so *config.ApplicationConfig, opts ...model.Option) []model.Option {
name := c.Name
if name == "" {
name = c.Model
Expand Down
4 changes: 2 additions & 2 deletions core/backend/rerank.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (

func Rerank(modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {

opts := ModelOptions(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)})
rerankModel, err := loader.BackendLoader(opts...)
opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
rerankModel, err := loader.Load(opts...)
if err != nil {
return nil, err
}
Expand Down
5 changes: 2 additions & 3 deletions core/backend/soundgeneration.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@ func SoundGeneration(
backendConfig config.BackendConfig,
) (string, *proto.Result, error) {

opts := ModelOptions(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)})

soundGenModel, err := loader.BackendLoader(opts...)
opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
soundGenModel, err := loader.Load(opts...)
if err != nil {
return "", nil, err
}
Expand Down
19 changes: 9 additions & 10 deletions core/backend/stores.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,15 @@ import (
)

func StoreBackend(sl *model.ModelLoader, appConfig *config.ApplicationConfig, storeName string) (grpc.Backend, error) {
if storeName == "" {
storeName = "default"
}
if storeName == "" {
storeName = "default"
}

sc := []model.Option{
model.WithBackendString(model.LocalStoreBackend),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithModel(storeName),
}
sc := []model.Option{
model.WithBackendString(model.LocalStoreBackend),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithModel(storeName),
}

return sl.BackendLoader(sc...)
return sl.Load(sc...)
}

6 changes: 2 additions & 4 deletions core/backend/token_metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@ func TokenMetrics(
appConfig *config.ApplicationConfig,
backendConfig config.BackendConfig) (*proto.MetricsResponse, error) {

opts := ModelOptions(backendConfig, appConfig, []model.Option{
model.WithModel(modelFile),
})
model, err := loader.BackendLoader(opts...)
opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
model, err := loader.Load(opts...)
if err != nil {
return nil, err
}
Expand Down
8 changes: 3 additions & 5 deletions core/backend/tokenize.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,13 @@ func ModelTokenize(s string, loader *model.ModelLoader, backendConfig config.Bac
var inferenceModel grpc.Backend
var err error

opts := ModelOptions(backendConfig, appConfig, []model.Option{
model.WithModel(modelFile),
})
opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))

if backendConfig.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)
inferenceModel, err = loader.Load(opts...)
} else {
opts = append(opts, model.WithBackendString(backendConfig.Backend))
inferenceModel, err = loader.BackendLoader(opts...)
inferenceModel, err = loader.Load(opts...)
}
if err != nil {
return schema.TokenizeResponse{}, err
Expand Down
4 changes: 2 additions & 2 deletions core/backend/transcript.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ func ModelTranscription(audio, language string, translate bool, ml *model.ModelL
backendConfig.Backend = model.WhisperBackend
}

opts := ModelOptions(backendConfig, appConfig, []model.Option{})
opts := ModelOptions(backendConfig, appConfig)

transcriptionModel, err := ml.BackendLoader(opts...)
transcriptionModel, err := ml.Load(opts...)
if err != nil {
return nil, err
}
Expand Down
8 changes: 2 additions & 6 deletions core/backend/tts.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,8 @@ func ModelTTS(
bb = model.PiperBackend
}

opts := ModelOptions(backendConfig, appConfig, []model.Option{
model.WithBackendString(bb),
model.WithModel(modelFile),
})

ttsModel, err := loader.BackendLoader(opts...)
opts := ModelOptions(backendConfig, appConfig, model.WithBackendString(bb), model.WithModel(modelFile))
ttsModel, err := loader.Load(opts...)
if err != nil {
return "", nil, err
}
Expand Down
9 changes: 2 additions & 7 deletions core/startup/startup.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,15 +160,10 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode

log.Debug().Msgf("Auto loading model %s into memory from file: %s", m, cfg.Model)

o := backend.ModelOptions(*cfg, options, []model.Option{})
o := backend.ModelOptions(*cfg, options)

var backendErr error
if cfg.Backend != "" {
o = append(o, model.WithBackendString(cfg.Backend))
_, backendErr = ml.BackendLoader(o...)
} else {
_, backendErr = ml.GreedyLoader(o...)
}
_, backendErr = ml.Load(o...)
if backendErr != nil {
return nil, nil, nil, err
}
Expand Down
10 changes: 7 additions & 3 deletions pkg/model/initializers.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ func (ml *ModelLoader) ListAvailableBackends(assetdir string) ([]string, error)
return orderBackends(backends)
}

func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err error) {
func (ml *ModelLoader) backendLoader(opts ...Option) (client grpc.Backend, err error) {
o := NewOptions(opts...)

log.Info().Msgf("Loading model '%s' with backend %s", o.modelID, o.backendString)
Expand Down Expand Up @@ -500,7 +500,7 @@ func (ml *ModelLoader) stopActiveBackends(modelID string, singleActiveBackend bo
}
}

func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) {
func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) {
o := NewOptions(opts...)

// Return earlier if we have a model already loaded
Expand All @@ -513,6 +513,10 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) {

ml.stopActiveBackends(o.modelID, o.singleActiveBackend)

if o.backendString != "" {
return ml.backendLoader(opts...)
}

var err error

// get backends embedded in the binary
Expand All @@ -536,7 +540,7 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) {
WithBackendString(key),
}...)

model, modelerr := ml.BackendLoader(options...)
model, modelerr := ml.backendLoader(options...)
if modelerr == nil && model != nil {
log.Info().Msgf("[%s] Loads OK", key)
return model, nil
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/stores_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs"
}

sl = model.NewModelLoader("")
sc, err = sl.BackendLoader(storeOpts...)
sc, err = sl.Load(storeOpts...)
Expect(err).ToNot(HaveOccurred())
Expect(sc).ToNot(BeNil())
})
Expand Down