Skip to content

Commit

Permalink
Feat Implement postgres as vector storage (#146)
Browse files Browse the repository at this point in the history
  • Loading branch information
henomis authored Nov 2, 2023
1 parent 83195d2 commit 940fca7
Show file tree
Hide file tree
Showing 3 changed files with 370 additions and 7 deletions.
123 changes: 123 additions & 0 deletions examples/embeddings/postgres/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package main

import (
"context"
"database/sql"
"fmt"

openaiembedder "github.com/henomis/lingoose/embedder/openai"
"github.com/henomis/lingoose/index"
indexoption "github.com/henomis/lingoose/index/option"
"github.com/henomis/lingoose/index/vectordb/postgres"
"github.com/henomis/lingoose/llm/openai"
"github.com/henomis/lingoose/loader"
"github.com/henomis/lingoose/prompt"
"github.com/henomis/lingoose/textsplitter"
// uncomment to use postgres
// _ "github.com/lib/pq"
)

// download https://raw.githubusercontent.com/hwchase17/chat-your-data/master/state_of_the_union.txt

func main() {

connStr := "user=root sslmode=disable password=pass dbname=test host=localhost port=5432"
db, err := sql.Open("postgres", connStr)
if err != nil {
panic(err)
}
defer db.Close()

err = db.Ping()
if err != nil {
panic(err)
}

index := index.New(
postgres.New(
postgres.Options{
DB: db,
Table: "test",
CreateIndex: &postgres.CreateIndexOptions{
Dimension: 1536,
Distance: postgres.DistanceCosine,
},
},
),
openaiembedder.New(openaiembedder.AdaEmbeddingV2),
).WithIncludeContents(true)

indexIsEmpty, err := index.IsEmpty(context.Background())
if err != nil {
panic(err)
}

if indexIsEmpty {
err = ingestData(index)
if err != nil {
panic(err)
}
}

query := "What is the purpose of the NATO Alliance?"
similarities, err := index.Query(
context.Background(),
query,
indexoption.WithTopK(3),
)
if err != nil {
panic(err)
}

content := ""
for _, similarity := range similarities {
fmt.Printf("Similarity: %f\n", similarity.Score)
fmt.Printf("Document: %s\n", similarity.Content())
fmt.Println("Metadata: ", similarity.Metadata)
fmt.Println("ID: ", similarity.ID)
fmt.Println("----------")
content += similarity.Content() + "\n"
}

llmOpenAI := openai.NewCompletion().WithVerbose(true)

prompt1 := prompt.NewPromptTemplate(
"Based on the following context answer to the question.\n\nContext:\n{{.context}}\n\nQuestion: {{.query}}").
WithInputs(
map[string]string{
"query": query,
"context": content,
},
)

err = prompt1.Format(nil)
if err != nil {
panic(err)
}

_, err = llmOpenAI.Completion(context.Background(), prompt1.String())
if err != nil {
panic(err)
}
}

func ingestData(redisIndex *index.Index) error {
documents, err := loader.NewDirectoryLoader(".", ".txt").Load(context.Background())
if err != nil {
return err
}

textSplitter := textsplitter.NewRecursiveCharacterTextSplitter(1000, 20)

documentChunks := textSplitter.SplitDocuments(documents)

for _, doc := range documentChunks {
fmt.Println(doc.Content)
fmt.Println("----------")
fmt.Println(doc.Metadata)
fmt.Println("----------")
fmt.Println()
}

return redisIndex.LoadFromDocuments(context.Background(), documentChunks)
}
232 changes: 232 additions & 0 deletions index/vectordb/postgres/postgres.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
package postgres

import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strconv"
"strings"

"github.com/google/uuid"
"github.com/henomis/lingoose/index"
"github.com/henomis/lingoose/index/option"
"github.com/henomis/lingoose/types"
)

type DB struct {
db *sql.DB
table string
createIndex *CreateIndexOptions
}

type Distance string

const (
DistanceCosine Distance = "<=>"
DistanceInnerProduct Distance = "<#>"
DistanceEuclidean Distance = "<->"
)

type CreateIndexOptions struct {
Dimension uint64
Distance Distance
}

type Options struct {
DB *sql.DB
Table string
CreateIndex *CreateIndexOptions
}

func New(options Options) *DB {
return &DB{
db: options.DB,
table: options.Table,
createIndex: options.CreateIndex,
}
}

func (d *DB) IsEmpty(ctx context.Context) (bool, error) {
err := d.createIndexIfRequired(ctx)
if err != nil {
return true, fmt.Errorf("%w: %w", index.ErrInternal, err)
}

var count int
err = d.db.QueryRow(
fmt.Sprintf("SELECT count(*) FROM %s", d.table),
).Scan(&count)
if err != nil {
return true, err
}

return count == 0, nil
}

func (d *DB) Insert(ctx context.Context, datas []index.Data) error {
err := d.createIndexIfRequired(ctx)
if err != nil {
return fmt.Errorf("%w: %w", index.ErrInternal, err)
}

var values []string
for _, data := range datas {
if data.ID == "" {
id, errUUID := uuid.NewUUID()
if errUUID != nil {
return errUUID
}
data.ID = id.String()
}

jsonMetadata, marshalErr := json.Marshal(data.Metadata)
if marshalErr != nil {
return fmt.Errorf("%w: %w", index.ErrInternal, marshalErr)
}

values = append(
values,
fmt.Sprintf(
"('%s','%s', '%s')",
data.ID,
floatToValues(data.Values),
string(jsonMetadata),
),
)
}

_, err = d.db.ExecContext(
ctx,
fmt.Sprintf("INSERT INTO %s (id, embedding, metadata) VALUES %s",
d.table,
strings.Join(values, ","),
),
)
if err != nil {
return fmt.Errorf("%w: %w", index.ErrInternal, err)
}

return nil
}

func (d *DB) Search(ctx context.Context, values []float64, options *option.Options) (index.SearchResults, error) {
return d.similaritySearch(ctx, values, options)
}

func (d *DB) similaritySearch(
ctx context.Context,
values []float64,
opts *option.Options,
) (index.SearchResults, error) {
if opts.Filter == nil {
opts.Filter = ""
}

queryVector := fmt.Sprintf("embedding %s '%s'", d.createIndex.Distance, floatToValues(values))
//nolint:gosec
query := fmt.Sprintf(
"SELECT id, embedding, metadata, %s AS score FROM %s %s ORDER BY %s LIMIT %d",
queryVector,
d.table,
queryVector,
opts.Filter,
opts.TopK,
)

rows, err := d.db.QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("%w: %w", index.ErrInternal, err)
}
defer rows.Close()

var results []index.SearchResult
for rows.Next() {
var id string
var embedding string
var jsonMetadata json.RawMessage
var score float64
err = rows.Scan(&id, &embedding, &jsonMetadata, &score)
if err != nil {
return nil, fmt.Errorf("%w: %w", index.ErrInternal, err)
}

metadata := make(types.Meta)
err = json.Unmarshal(jsonMetadata, &metadata)
if err != nil {
return nil, fmt.Errorf("%w: %w", index.ErrInternal, err)
}

var embeddingValues []float64
embeddingValues, err = valuesToFloats(embedding)
if err != nil {
return nil, fmt.Errorf("%w: %w", index.ErrInternal, err)
}

result := index.SearchResult{
Data: index.Data{
ID: id,
Metadata: metadata,
Values: embeddingValues,
},
Score: score,
}
results = append(results, result)
}

err = rows.Err()
if err != nil {
return nil, fmt.Errorf("%w: %w", index.ErrInternal, err)
}

return results, nil
}

func (d *DB) createIndexIfRequired(ctx context.Context) error {
if d.createIndex == nil {
return nil
}

_, err := d.db.Exec("CREATE EXTENSION IF NOT EXISTS vector")
if err != nil {
return fmt.Errorf("%w: %w", index.ErrInternal, err)
}

_, err = d.db.ExecContext(
ctx,
fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id UUID PRIMARY KEY, metadata json, embedding vector(%d))",
d.table, d.createIndex.Dimension),
)
if err != nil {
return fmt.Errorf("%w: %w", index.ErrInternal, err)
}

return nil
}

func floatToValues(floats []float64) string {
var b strings.Builder
b.WriteString("[")
for i, f := range floats {
if i > 0 {
b.WriteString(",")
}
b.WriteString(strconv.FormatFloat(f, 'f', -1, 64))
}
b.WriteString("]")
return b.String()
}

func valuesToFloats(s string) ([]float64, error) {
s = strings.Trim(s, "[]")
parts := strings.Split(s, ",")
floats := make([]float64, len(parts))
for i, p := range parts {
f, err := strconv.ParseFloat(p, 64)
if err != nil {
return nil, err
}
floats[i] = f
}
return floats, nil
}
Loading

0 comments on commit 940fca7

Please sign in to comment.