Skip to content

Commit

Permalink
Feat: add new QA pipeline mode refine (#134)
Browse files Browse the repository at this point in the history
  • Loading branch information
henomis authored Sep 16, 2023
1 parent 80ab33a commit 96e602d
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 8 deletions.
2 changes: 1 addition & 1 deletion examples/embeddings/simplekb/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
)

func main() {
docs, _ := loader.NewPDFToTextLoader("./kb").WithPDFToTextPath("/opt/homebrew/bin/pdftotext").WithTextSplitter(textsplitter.NewRecursiveCharacterTextSplitter(2000, 200)).Load(context.Background())
docs, _ := loader.NewPDFToTextLoader("./kb").WithTextSplitter(textsplitter.NewRecursiveCharacterTextSplitter(2000, 200)).Load(context.Background())
index := simplevectorindex.New("db", ".", openaiembedder.New(openaiembedder.AdaEmbeddingV2))
index.LoadFromDocuments(context.Background(), docs)
qapipeline.New(openai.NewChat().WithVerbose(true)).WithIndex(index).Query(context.Background(), "What is the NATO purpose?", option.WithTopK(1))
Expand Down
87 changes: 80 additions & 7 deletions pipeline/qa/qa.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,24 @@ const (

//nolint:lll
qaTubeUserPromptTemplate = "Based on the following context answer to the question.\n\nContext:\n{{.context}}\n\nQuestion: {{.query}}"

qaTubePromptTemplate = "Context information is below\n--------------------\n{{.context}}\n" +
"--------------------\nGiven the context information and not prior knowledge, answer the following question.\n" +
"Question: {{.query}}\nAnswer:"
qaTubePromptRefineTemplate = "The original question is as follows: {{.query}}\n" +
"We have provided an existing answer: {{.answer}}\n" +
"We have the opportunity to refine the existing answer (only if needed) with some more context below.\n" +
"----------------------\n{{.context}}\n----------------------\n" +
"Given the new context, refine the original answer to better answer the question.\n" +
"If the context isn't useful, return the original answer.\n" +
"Refined answer:"
)

type Mode string

const (
ModeSimple Mode = "simple"
ModeRefine Mode = "refine"
)

type Index interface {
Expand All @@ -29,6 +47,7 @@ type Index interface {
type QAPipeline struct {
llmEngine pipeline.LlmEngine
pipeline *pipeline.Pipeline
mode Mode
index Index
}

Expand Down Expand Up @@ -58,6 +77,7 @@ func New(llmEngine pipeline.LlmEngine) *QAPipeline {
llmEngine: llmEngine,
pipeline: pipeline.New(tube),
index: nil,
mode: ModeSimple,
}
}

Expand All @@ -80,6 +100,11 @@ func (q *QAPipeline) WithIndex(index Index) *QAPipeline {
return q
}

func (q *QAPipeline) WithMode(mode Mode) *QAPipeline {
q.mode = mode
return q
}

func (q *QAPipeline) Query(ctx context.Context, query string, opts ...indexoption.Option) (types.M, error) {
if q.index == nil {
return nil, fmt.Errorf("retriever is not defined")
Expand All @@ -99,11 +124,59 @@ func (q *QAPipeline) Run(ctx context.Context, query string, documents []document
content += document.Content + "\n"
}

return q.pipeline.Run(
ctx,
types.M{
"query": query,
"context": content,
},
)
if q.mode == ModeSimple {
return q.pipeline.Run(
ctx,
types.M{
"query": query,
"context": content,
},
)
}

return q.runRefine(ctx, query, documents)
}

func (q *QAPipeline) runRefine(ctx context.Context, query string, documents []document.Document) (types.M, error) {
var currentResponse string
var output types.M
var err error

for i, document := range documents {
context := document.Content

var qaPrompt *prompt.Template
if i == 0 {
qaPrompt = prompt.NewPromptTemplate(qaTubePromptTemplate)
} else {
qaPrompt = prompt.NewPromptTemplate(qaTubePromptRefineTemplate)
}

llm := pipeline.Llm{
LlmEngine: q.llmEngine,
LlmMode: pipeline.LlmModeCompletion,
Prompt: qaPrompt,
}

tube := pipeline.NewTube(llm)
q.pipeline = pipeline.New(tube)
output, err = q.pipeline.Run(
ctx, types.M{
"query": query,
"answer": currentResponse,
"context": context,
},
)
if err != nil {
return nil, err
}

response, ok := output[types.DefaultOutputKey].(string)
if !ok {
return nil, fmt.Errorf("invalid response type")
}
currentResponse = response
}

return output, nil
}

0 comments on commit 96e602d

Please sign in to comment.