-
Notifications
You must be signed in to change notification settings - Fork 58
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feat Implement postgres as vector storage (#146)
- Loading branch information
Showing
3 changed files
with
370 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
package main | ||
|
||
import ( | ||
"context" | ||
"database/sql" | ||
"fmt" | ||
|
||
openaiembedder "github.com/henomis/lingoose/embedder/openai" | ||
"github.com/henomis/lingoose/index" | ||
indexoption "github.com/henomis/lingoose/index/option" | ||
"github.com/henomis/lingoose/index/vectordb/postgres" | ||
"github.com/henomis/lingoose/llm/openai" | ||
"github.com/henomis/lingoose/loader" | ||
"github.com/henomis/lingoose/prompt" | ||
"github.com/henomis/lingoose/textsplitter" | ||
// uncomment to use postgres | ||
// _ "github.com/lib/pq" | ||
) | ||
|
||
// download https://raw.githubusercontent.com/hwchase17/chat-your-data/master/state_of_the_union.txt | ||
|
||
func main() { | ||
|
||
connStr := "user=root sslmode=disable password=pass dbname=test host=localhost port=5432" | ||
db, err := sql.Open("postgres", connStr) | ||
if err != nil { | ||
panic(err) | ||
} | ||
defer db.Close() | ||
|
||
err = db.Ping() | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
index := index.New( | ||
postgres.New( | ||
postgres.Options{ | ||
DB: db, | ||
Table: "test", | ||
CreateIndex: &postgres.CreateIndexOptions{ | ||
Dimension: 1536, | ||
Distance: postgres.DistanceCosine, | ||
}, | ||
}, | ||
), | ||
openaiembedder.New(openaiembedder.AdaEmbeddingV2), | ||
).WithIncludeContents(true) | ||
|
||
indexIsEmpty, err := index.IsEmpty(context.Background()) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
if indexIsEmpty { | ||
err = ingestData(index) | ||
if err != nil { | ||
panic(err) | ||
} | ||
} | ||
|
||
query := "What is the purpose of the NATO Alliance?" | ||
similarities, err := index.Query( | ||
context.Background(), | ||
query, | ||
indexoption.WithTopK(3), | ||
) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
content := "" | ||
for _, similarity := range similarities { | ||
fmt.Printf("Similarity: %f\n", similarity.Score) | ||
fmt.Printf("Document: %s\n", similarity.Content()) | ||
fmt.Println("Metadata: ", similarity.Metadata) | ||
fmt.Println("ID: ", similarity.ID) | ||
fmt.Println("----------") | ||
content += similarity.Content() + "\n" | ||
} | ||
|
||
llmOpenAI := openai.NewCompletion().WithVerbose(true) | ||
|
||
prompt1 := prompt.NewPromptTemplate( | ||
"Based on the following context answer to the question.\n\nContext:\n{{.context}}\n\nQuestion: {{.query}}"). | ||
WithInputs( | ||
map[string]string{ | ||
"query": query, | ||
"context": content, | ||
}, | ||
) | ||
|
||
err = prompt1.Format(nil) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
_, err = llmOpenAI.Completion(context.Background(), prompt1.String()) | ||
if err != nil { | ||
panic(err) | ||
} | ||
} | ||
|
||
func ingestData(redisIndex *index.Index) error { | ||
documents, err := loader.NewDirectoryLoader(".", ".txt").Load(context.Background()) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
textSplitter := textsplitter.NewRecursiveCharacterTextSplitter(1000, 20) | ||
|
||
documentChunks := textSplitter.SplitDocuments(documents) | ||
|
||
for _, doc := range documentChunks { | ||
fmt.Println(doc.Content) | ||
fmt.Println("----------") | ||
fmt.Println(doc.Metadata) | ||
fmt.Println("----------") | ||
fmt.Println() | ||
} | ||
|
||
return redisIndex.LoadFromDocuments(context.Background(), documentChunks) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,232 @@ | ||
package postgres | ||
|
||
import ( | ||
"context" | ||
"database/sql" | ||
"encoding/json" | ||
"fmt" | ||
"strconv" | ||
"strings" | ||
|
||
"github.com/google/uuid" | ||
"github.com/henomis/lingoose/index" | ||
"github.com/henomis/lingoose/index/option" | ||
"github.com/henomis/lingoose/types" | ||
) | ||
|
||
type DB struct { | ||
db *sql.DB | ||
table string | ||
createIndex *CreateIndexOptions | ||
} | ||
|
||
type Distance string | ||
|
||
const ( | ||
DistanceCosine Distance = "<=>" | ||
DistanceInnerProduct Distance = "<#>" | ||
DistanceEuclidean Distance = "<->" | ||
) | ||
|
||
type CreateIndexOptions struct { | ||
Dimension uint64 | ||
Distance Distance | ||
} | ||
|
||
type Options struct { | ||
DB *sql.DB | ||
Table string | ||
CreateIndex *CreateIndexOptions | ||
} | ||
|
||
func New(options Options) *DB { | ||
return &DB{ | ||
db: options.DB, | ||
table: options.Table, | ||
createIndex: options.CreateIndex, | ||
} | ||
} | ||
|
||
func (d *DB) IsEmpty(ctx context.Context) (bool, error) { | ||
err := d.createIndexIfRequired(ctx) | ||
if err != nil { | ||
return true, fmt.Errorf("%w: %w", index.ErrInternal, err) | ||
} | ||
|
||
var count int | ||
err = d.db.QueryRow( | ||
fmt.Sprintf("SELECT count(*) FROM %s", d.table), | ||
).Scan(&count) | ||
if err != nil { | ||
return true, err | ||
} | ||
|
||
return count == 0, nil | ||
} | ||
|
||
func (d *DB) Insert(ctx context.Context, datas []index.Data) error { | ||
err := d.createIndexIfRequired(ctx) | ||
if err != nil { | ||
return fmt.Errorf("%w: %w", index.ErrInternal, err) | ||
} | ||
|
||
var values []string | ||
for _, data := range datas { | ||
if data.ID == "" { | ||
id, errUUID := uuid.NewUUID() | ||
if errUUID != nil { | ||
return errUUID | ||
} | ||
data.ID = id.String() | ||
} | ||
|
||
jsonMetadata, marshalErr := json.Marshal(data.Metadata) | ||
if marshalErr != nil { | ||
return fmt.Errorf("%w: %w", index.ErrInternal, marshalErr) | ||
} | ||
|
||
values = append( | ||
values, | ||
fmt.Sprintf( | ||
"('%s','%s', '%s')", | ||
data.ID, | ||
floatToValues(data.Values), | ||
string(jsonMetadata), | ||
), | ||
) | ||
} | ||
|
||
_, err = d.db.ExecContext( | ||
ctx, | ||
fmt.Sprintf("INSERT INTO %s (id, embedding, metadata) VALUES %s", | ||
d.table, | ||
strings.Join(values, ","), | ||
), | ||
) | ||
if err != nil { | ||
return fmt.Errorf("%w: %w", index.ErrInternal, err) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func (d *DB) Search(ctx context.Context, values []float64, options *option.Options) (index.SearchResults, error) { | ||
return d.similaritySearch(ctx, values, options) | ||
} | ||
|
||
func (d *DB) similaritySearch( | ||
ctx context.Context, | ||
values []float64, | ||
opts *option.Options, | ||
) (index.SearchResults, error) { | ||
if opts.Filter == nil { | ||
opts.Filter = "" | ||
} | ||
|
||
queryVector := fmt.Sprintf("embedding %s '%s'", d.createIndex.Distance, floatToValues(values)) | ||
//nolint:gosec | ||
query := fmt.Sprintf( | ||
"SELECT id, embedding, metadata, %s AS score FROM %s %s ORDER BY %s LIMIT %d", | ||
queryVector, | ||
d.table, | ||
queryVector, | ||
opts.Filter, | ||
opts.TopK, | ||
) | ||
|
||
rows, err := d.db.QueryContext(ctx, query) | ||
if err != nil { | ||
return nil, fmt.Errorf("%w: %w", index.ErrInternal, err) | ||
} | ||
defer rows.Close() | ||
|
||
var results []index.SearchResult | ||
for rows.Next() { | ||
var id string | ||
var embedding string | ||
var jsonMetadata json.RawMessage | ||
var score float64 | ||
err = rows.Scan(&id, &embedding, &jsonMetadata, &score) | ||
if err != nil { | ||
return nil, fmt.Errorf("%w: %w", index.ErrInternal, err) | ||
} | ||
|
||
metadata := make(types.Meta) | ||
err = json.Unmarshal(jsonMetadata, &metadata) | ||
if err != nil { | ||
return nil, fmt.Errorf("%w: %w", index.ErrInternal, err) | ||
} | ||
|
||
var embeddingValues []float64 | ||
embeddingValues, err = valuesToFloats(embedding) | ||
if err != nil { | ||
return nil, fmt.Errorf("%w: %w", index.ErrInternal, err) | ||
} | ||
|
||
result := index.SearchResult{ | ||
Data: index.Data{ | ||
ID: id, | ||
Metadata: metadata, | ||
Values: embeddingValues, | ||
}, | ||
Score: score, | ||
} | ||
results = append(results, result) | ||
} | ||
|
||
err = rows.Err() | ||
if err != nil { | ||
return nil, fmt.Errorf("%w: %w", index.ErrInternal, err) | ||
} | ||
|
||
return results, nil | ||
} | ||
|
||
func (d *DB) createIndexIfRequired(ctx context.Context) error { | ||
if d.createIndex == nil { | ||
return nil | ||
} | ||
|
||
_, err := d.db.Exec("CREATE EXTENSION IF NOT EXISTS vector") | ||
if err != nil { | ||
return fmt.Errorf("%w: %w", index.ErrInternal, err) | ||
} | ||
|
||
_, err = d.db.ExecContext( | ||
ctx, | ||
fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id UUID PRIMARY KEY, metadata json, embedding vector(%d))", | ||
d.table, d.createIndex.Dimension), | ||
) | ||
if err != nil { | ||
return fmt.Errorf("%w: %w", index.ErrInternal, err) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func floatToValues(floats []float64) string { | ||
var b strings.Builder | ||
b.WriteString("[") | ||
for i, f := range floats { | ||
if i > 0 { | ||
b.WriteString(",") | ||
} | ||
b.WriteString(strconv.FormatFloat(f, 'f', -1, 64)) | ||
} | ||
b.WriteString("]") | ||
return b.String() | ||
} | ||
|
||
func valuesToFloats(s string) ([]float64, error) { | ||
s = strings.Trim(s, "[]") | ||
parts := strings.Split(s, ",") | ||
floats := make([]float64, len(parts)) | ||
for i, p := range parts { | ||
f, err := strconv.ParseFloat(p, 64) | ||
if err != nil { | ||
return nil, err | ||
} | ||
floats[i] = f | ||
} | ||
return floats, nil | ||
} |
Oops, something went wrong.