Skip to content

Commit

Permalink
customize prompt template
Browse files Browse the repository at this point in the history
  • Loading branch information
lyliyu committed Dec 10, 2023
1 parent 6a1272c commit c3216a2
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 7 deletions.
15 changes: 13 additions & 2 deletions framework/feature_factory/llm_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
TitleExtractor,
MetadataFeatureExtractor
)
from llama_index.node_parser.extractors.metadata_extractors import (
DEFAULT_TITLE_COMBINE_TEMPLATE,
DEFAULT_TITLE_NODE_TEMPLATE
)
from llama_index.text_splitter import TokenTextSplitter
from llama_index.schema import MetadataMode, Document as Document
from langchain.docstore.document import Document as LCDocument
Expand Down Expand Up @@ -182,15 +186,22 @@ def apply(self, filename: str) -> str:

class LlamaIndexTitleExtractor(LLMTool):

def __init__(self, llm_def, nodes, prompt_template=DEFAULT_TITLE_NODE_TEMPLATE, combine_template=DEFAULT_TITLE_COMBINE_TEMPLATE) -> None:
    """Configure a lazily-created LlamaIndex title extractor.

    :param llm_def: wrapper object whose ``create()``/``get_instance()``
        supply the underlying LLM (see ``create`` below).
    :param nodes: value forwarded to ``TitleExtractor(nodes=...)`` —
        presumably the number of nodes sampled per document for title
        extraction; confirm against the llama_index API.
    :param prompt_template: per-node title prompt; defaults to
        llama_index's ``DEFAULT_TITLE_NODE_TEMPLATE``.
    :param combine_template: prompt used to combine node-level title
        candidates; defaults to ``DEFAULT_TITLE_COMBINE_TEMPLATE``.
    """
    super().__init__()
    self.llm_def = llm_def
    self.nodes = nodes
    self.prompt_template = prompt_template
    self.combine_template = combine_template

def create(self):
    """Instantiate the wrapped ``TitleExtractor`` on first use.

    ``_require_init()`` (from the LLMTool base) guards against
    re-initialization; the LLM wrapper is created first so its live
    instance can be handed to the extractor. The stale diff-residue
    line that built a second, template-less ``TitleExtractor`` has
    been removed — only the fully-configured instance is kept.
    """
    if super()._require_init():
        self.llm_def.create()
        self._instance = TitleExtractor(
            nodes=self.nodes,
            llm=self.llm_def.get_instance(),
            node_template=self.prompt_template,
            combine_template=self.combine_template,
        )

def apply(self):
self.create()
Expand Down
22 changes: 19 additions & 3 deletions notebooks/feature_factory_llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,20 @@ def apply(self):

# COMMAND ----------

# Instruction-style prompt (MPT/Alpaca "### Instruction / ### Response"
# format) used to title a single text node; {context_str} is filled in
# by the extractor with the node's text.
TITLE_NODE_TEMPLATE = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nGive a title that summarizes this paragraph: {context_str}.\n### Response:\n"

# COMMAND ----------

# Companion prompt used to combine the per-node title candidates into a
# single document title.
TITLE_COMBINE_TEMPLATE = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nGive a title that summarizes the following: {context_str}.\n### Response:\n"

# COMMAND ----------

# Wire the custom prompt templates into the title-extractor tool; the
# LLM itself is created lazily by the wrapper when the tool is applied.
title_extractor = LlamaIndexTitleExtractor(
    nodes=5,
    llm_def=MPT7b(),
    prompt_template=TITLE_NODE_TEMPLATE,
    combine_template=TITLE_COMBINE_TEMPLATE,
)

# COMMAND ----------

Expand Down Expand Up @@ -89,17 +102,20 @@ def apply(self):

# COMMAND ----------

# Run the feature-factory pipeline over the source documents; the stale
# diff-residue line that invoked the same call with the old placeholder
# string has been removed so the pipeline runs exactly once.
df = ff.assemble_llm_feature(
    spark,
    srcDirectory="directory to your documents",
    llmFeature=llm_feature,
    partitionNum=partition_num,
)

# COMMAND ----------

display(df)

# COMMAND ----------

# Persist the chunks; replace the placeholders with a real three-level
# Unity Catalog name before running.
df.write.mode("overwrite").saveAsTable("<catalog>.<database>.<table>")

# COMMAND ----------

# MAGIC %sql select * from liyu_demo.va.chunks

# COMMAND ----------


10 changes: 8 additions & 2 deletions test/test_chunking.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,16 @@ class TestCatalog(LLMCatalogBase):
# Reader that loads source documents via LlamaIndex.
doc_reader = LlamaIndexDocReader()

# Text splitter: 1024-character chunks with 64-character overlap. The
# stale diff-residue line that built a default (argument-less) splitter
# has been removed so only the configured instance remains.
doc_splitter = LangChainRecursiveCharacterTextSplitter(
    chunk_size=1024,
    chunk_overlap=64,
)

# LLM feature definition; the attribute name ("chunk_col_name") becomes
# the column name in the resulting dataframe.
chunk_col_name = LLMFeature(
    reader=doc_reader,
    splitter=doc_splitter,
)

llm_feature = TestCatalog.get_all_features()
assert llm_feature.name == "chunk_col_name"
Expand Down

0 comments on commit c3216a2

Please sign in to comment.