-
Notifications
You must be signed in to change notification settings - Fork 0
/
cvpartner_qa.py
81 lines (67 loc) · 2.47 KB
/
cvpartner_qa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from typing import Any, Dict, List, Optional, Tuple

from langchain.callbacks.manager import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.question_answering import load_qa_chain
from langchain.pydantic_v1 import Extra
from langchain.schema.language_model import BaseLanguageModel
from langchain.vectorstores.chroma import Chroma


class CVPartnerQA(Chain):
    """Question-answering chain over CV documents stored in a Chroma vector store.

    On each call the chain runs a max-marginal-relevance search on
    ``vector_store`` for the question. The search is restricted by a metadata
    filter on ``email`` (lower-cased). The retrieved documents and the question
    are passed to ``combine_documents_chain``, and the answer is returned under
    ``output_key``.

    The synchronous (``_call``) and asynchronous (``_acall``) paths share the
    same input handling and retrieval parameters.
    """

    # Chain that combines the retrieved documents and the question into an answer.
    combine_documents_chain: BaseCombineDocumentsChain
    # Names of the expected inputs: index 0 is the question, index 1 is the
    # email used to filter documents. Must contain at least two entries.
    input_key_list: List[str] = ["query", "email"]  #: :meta private:
    output_key: str = "result"  #: :meta private:
    # Store searched for supporting documents, filtered on their ``email`` metadata.
    vector_store: Chroma
    # Number of documents returned by the max-marginal-relevance search.
    k: int = 5

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True
        allow_population_by_field_name = True

    @property
    def input_keys(self) -> List[str]:
        """Input keys this chain expects (taken from ``input_key_list``)."""
        return self.input_key_list

    @property
    def output_keys(self) -> List[str]:
        """Single output key under which the answer is returned."""
        return [self.output_key]

    @classmethod
    def from_chain_type(
        cls,
        llm: BaseLanguageModel,
        vector_store: Chroma,
        chain_type: str = "stuff",
        chain_type_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> "CVPartnerQA":
        """Load chain from chain type.

        Args:
            llm: Language model used by the QA combine-documents chain.
            vector_store: Chroma store to retrieve CV documents from.
            chain_type: Chain type passed to ``load_qa_chain`` (e.g. ``"stuff"``).
            chain_type_kwargs: Extra keyword arguments forwarded to ``load_qa_chain``.
            **kwargs: Additional fields for this chain (e.g. ``k``, ``output_key``).

        Returns:
            A configured instance of ``cls``.
        """
        combine_documents_chain = load_qa_chain(
            llm, chain_type=chain_type, **(chain_type_kwargs or {})
        )
        return cls(
            combine_documents_chain=combine_documents_chain,
            vector_store=vector_store,
            **kwargs,
        )

    def _parse_inputs(self, inputs: Dict[str, Any]) -> Tuple[Any, Dict[str, Any]]:
        """Return the question and the Chroma metadata filter built from ``inputs``.

        The question is read from ``input_key_list[0]`` and the email from
        ``input_key_list[1]`` (``"query"`` and ``"email"`` by default). This keeps
        the lookup consistent with the keys the base class validates.
        """
        question = inputs[self.input_key_list[0]]
        email = inputs[self.input_key_list[1]]
        # The filter value is lower-cased; this presumes the stored ``email``
        # metadata is lower-cased as well — confirm against the ingestion code.
        return question, {"email": email.lower()}

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Synchronously retrieve the person's CV documents and answer the question.

        Args:
            inputs: Mapping containing the question and email keys.
            run_manager: Callback manager for this run; a no-op one is used if absent.

        Returns:
            ``{output_key: answer}``.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question, search_filter = self._parse_inputs(inputs)
        docs = self.vector_store.max_marginal_relevance_search(
            question, k=self.k, filter=search_filter
        )
        answer = self.combine_documents_chain.run(
            input_documents=docs,
            question=question,
            callbacks=_run_manager.get_child(),
        )
        return {self.output_key: answer}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Asynchronously retrieve the person's CV documents and answer the question.

        Args:
            inputs: Mapping containing the question and email keys.
            run_manager: Async callback manager for this run; a no-op one is used if absent.

        Returns:
            ``{output_key: answer}``.
        """
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        question, search_filter = self._parse_inputs(inputs)
        docs = await self.vector_store.amax_marginal_relevance_search(
            question, k=self.k, filter=search_filter
        )
        answer = await self.combine_documents_chain.arun(
            input_documents=docs,
            question=question,
            callbacks=_run_manager.get_child(),
        )
        return {self.output_key: answer}