From c3216a2afc8c5430a807b6816700b27b8b0fe68f Mon Sep 17 00:00:00 2001 From: Li Yu Date: Sun, 10 Dec 2023 15:33:46 -0500 Subject: [PATCH] customize prompt template --- framework/feature_factory/llm_tools.py | 15 +++++++++++++-- notebooks/feature_factory_llms.py | 22 +++++++++++++++++++--- test/test_chunking.py | 10 ++++++++-- 3 files changed, 40 insertions(+), 7 deletions(-) diff --git a/framework/feature_factory/llm_tools.py b/framework/feature_factory/llm_tools.py index 246804e..cfea585 100644 --- a/framework/feature_factory/llm_tools.py +++ b/framework/feature_factory/llm_tools.py @@ -7,6 +7,10 @@ TitleExtractor, MetadataFeatureExtractor ) +from llama_index.node_parser.extractors.metadata_extractors import ( + DEFAULT_TITLE_COMBINE_TEMPLATE, + DEFAULT_TITLE_NODE_TEMPLATE +) from llama_index.text_splitter import TokenTextSplitter from llama_index.schema import MetadataMode, Document as Document from langchain.docstore.document import Document as LCDocument @@ -182,15 +186,22 @@ def apply(self, filename: str) -> str: class LlamaIndexTitleExtractor(LLMTool): - def __init__(self, llm_def, nodes) -> None: + def __init__(self, llm_def, nodes, prompt_template=DEFAULT_TITLE_NODE_TEMPLATE, combine_template=DEFAULT_TITLE_COMBINE_TEMPLATE) -> None: super().__init__() self.llm_def = llm_def self.nodes = nodes + self.prompt_template = prompt_template + self.combine_template = combine_template def create(self): if super()._require_init(): self.llm_def.create() - self._instance = TitleExtractor(nodes=self.nodes, llm=self.llm_def.get_instance()) + self._instance = TitleExtractor( + nodes=self.nodes, + llm=self.llm_def.get_instance(), + node_template=self.prompt_template, + combine_template= self.combine_template + ) def apply(self): self.create() diff --git a/notebooks/feature_factory_llms.py b/notebooks/feature_factory_llms.py index 5f65cdc..fae087d 100644 --- a/notebooks/feature_factory_llms.py +++ b/notebooks/feature_factory_llms.py @@ -58,7 +58,20 @@ def apply(self): # COMMAND ---------- -title_extractor = LlamaIndexTitleExtractor(nodes=5, llm_def = MPT7b()) +TITLE_NODE_TEMPLATE = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nGive a title that summarizes this paragraph: {context_str}.\n### Response:\n" + +# COMMAND ---------- + +TITLE_COMBINE_TEMPLATE = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nGive a title that summarizes the following: {context_str}.\n### Response:\n" + +# COMMAND ---------- + +title_extractor = LlamaIndexTitleExtractor( + nodes=5, + llm_def = MPT7b(), + prompt_template = TITLE_NODE_TEMPLATE, + combine_template = TITLE_COMBINE_TEMPLATE +) # COMMAND ---------- @@ -89,7 +102,7 @@ def apply(self): # COMMAND ---------- -df = ff.assemble_llm_feature(spark, srcDirectory= "your source document directory", llmFeature=llm_feature, partitionNum=partition_num) +df = ff.assemble_llm_feature(spark, srcDirectory= "directory to your documents", llmFeature=llm_feature, partitionNum=partition_num) # COMMAND ---------- @@ -97,9 +110,12 @@ def apply(self): # COMMAND ---------- -df.write.mode("overwrite").saveAsTable("..") +df.write.mode("overwrite").saveAsTable("..
") # COMMAND ---------- +# MAGIC %sql select * from liyu_demo.va.chunks + +# COMMAND ---------- diff --git a/test/test_chunking.py b/test/test_chunking.py index 90dfce2..1db6a5b 100644 --- a/test/test_chunking.py +++ b/test/test_chunking.py @@ -152,10 +152,16 @@ class TestCatalog(LLMCatalogBase): doc_reader = LlamaIndexDocReader() # define a text splitter - doc_splitter = LangChainRecursiveCharacterTextSplitter() + doc_splitter = LangChainRecursiveCharacterTextSplitter( + chunk_size=1024, + chunk_overlap=64 + ) # define a LLM feature, the name is the column name in the result dataframe - chunk_col_name = LLMFeature(reader=doc_reader, splitter=doc_splitter) + chunk_col_name = LLMFeature( + reader=doc_reader, + splitter=doc_splitter + ) llm_feature = TestCatalog.get_all_features() assert llm_feature.name == "chunk_col_name"