diff --git a/embeddings.go b/embeddings.go index 5ba91f23..480ee0e2 100644 --- a/embeddings.go +++ b/embeddings.go @@ -12,9 +12,19 @@ import ( // to generate Embedding vectors. type EmbeddingModel int +func ConvertStr2EmbeddingModel(modelName string) EmbeddingModel { + if val, ok := stringToEnum[modelName]; ok { + return val + } + return Unknown +} + // String implements the fmt.Stringer interface. func (e EmbeddingModel) String() string { - return enumToString[e] + if val, ok := enumToString[e]; ok { + return val + } + return "Unknown" } // MarshalText implements the encoding.TextMarshaler interface.