diff --git a/framework/feature_factory/llm_tools.py b/framework/feature_factory/llm_tools.py
index b9bf844..246804e 100644
--- a/framework/feature_factory/llm_tools.py
+++ b/framework/feature_factory/llm_tools.py
@@ -4,7 +4,8 @@
 from llama_index.node_parser import SimpleNodeParser
 from llama_index.node_parser.extractors import (
     MetadataExtractor,
-    TitleExtractor
+    TitleExtractor,
+    MetadataFeatureExtractor
 )
 from llama_index.text_splitter import TokenTextSplitter
 from llama_index.schema import MetadataMode, Document as Document
@@ -24,6 +25,7 @@ class LLMTool(ABC):
     """
     def __init__(self) -> None:
         self._initialized = False
+        self._instance = None
 
     def _require_init(self) -> bool:
         if self._initialized:
@@ -40,6 +42,9 @@ def apply(self):
 
     def create(self):
         ...
+    def get_instance(self):
+        return self._instance
+
 class DocReader(LLMTool):
     """
     Generic class for doc reader.
@@ -165,16 +170,31 @@ def apply(self, filename: str) -> str:
 
 
-class LLMDef(LLMTool):
-    """ A generic class to define LLM instance e.g. using HuggingFace APIs.
-    An example can be found at notebooks/feature_factory_llms.py
-    """
-    def __init__(self) -> None:
-        self._instance = None
+# class LLMDef(LLMTool):
+#     """ A generic class to define LLM instance e.g. using HuggingFace APIs.
+#     An example can be found at notebooks/feature_factory_llms.py
+#     """
+#     def __init__(self) -> None:
+#         self._instance = None
 
-    def get_instance(self):
-        return self._instance
+#     def get_instance(self):
+#         return self._instance
+
+class LlamaIndexTitleExtractor(LLMTool):
+
+    def __init__(self, llm_def, nodes) -> None:
+        super().__init__()
+        self.llm_def = llm_def
+        self.nodes = nodes
+
+    def create(self):
+        if super()._require_init():
+            self.llm_def.create()
+            self._instance = TitleExtractor(nodes=self.nodes, llm=self.llm_def.get_instance())
+    def apply(self):
+        self.create()
+
 
 class LlamaIndexDocSplitter(DocSplitter):
@@ -183,23 +203,23 @@ class LlamaIndexDocSplitter(DocSplitter):
-    `chunk_size`, `chunk_overlap` are the hyperparameters to tweak for better responses from LLMs.
-    `llm` is the LLM instance used for metadata extraction. If not provided, the splitter will generate text chunks only.
+    `chunk_size`, `chunk_overlap` are the hyperparameters to tweak for better responses from LLMs.
+    `extractors` is an optional list of LLMTool-based metadata extractors. If not provided, the splitter will generate text chunks only.
""" - def __init__(self, chunk_size:int=1024, chunk_overlap:int=64, llm:LLMDef=None) -> None: + def __init__(self, chunk_size:int=1024, chunk_overlap:int=64, extractors:List[LLMTool]=None) -> None: super().__init__() self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap - self.llm = llm + self.extractors = extractors def create(self): if super()._require_init(): text_splitter = TokenTextSplitter( separator=" ", chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap ) - if self.llm: - self.llm.create() + if self.extractors: + for extractor in self.extractors: + extractor.create() + extractor_instances = [e.get_instance() for e in self.extractors] metadata_extractor = MetadataExtractor( - extractors=[ - TitleExtractor(nodes=5, llm=self.llm.get_instance()) - ], + extractors=extractor_instances, in_place=False, ) else: diff --git a/notebooks/feature_factory_llms.py b/notebooks/feature_factory_llms.py index 92b98c5..5f65cdc 100644 --- a/notebooks/feature_factory_llms.py +++ b/notebooks/feature_factory_llms.py @@ -1,5 +1,5 @@ # Databricks notebook source -# MAGIC %pip install llama-index pypdf +# MAGIC %pip install llama-index==0.8.61 pypdf # COMMAND ---------- @@ -11,7 +11,11 @@ # COMMAND ---------- -from framework.feature_factory.llm_tools import LLMFeature, LlamaIndexDocReader, LlamaIndexDocSplitter, LLMDef +# MAGIC %pip list + +# COMMAND ---------- + +from framework.feature_factory.llm_tools import LLMFeature, LlamaIndexDocReader, LlamaIndexDocSplitter, LLMTool, LangChainRecursiveCharacterTextSplitter, LlamaIndexTitleExtractor from framework.feature_factory import Feature_Factory import torch from llama_index.llms import HuggingFaceLLM @@ -22,7 +26,7 @@ # COMMAND ---------- -class MPT7b(LLMDef): +class MPT7b(LLMTool): def create(self): torch.cuda.empty_cache() generate_params = { @@ -35,7 +39,7 @@ def create(self): "pad_token_id": 0 } - self._instance = HuggingFaceLLM( + llm = HuggingFaceLLM( max_new_tokens=256, generate_kwargs=generate_params, # system_prompt=system_prompt, @@ -46,21 +50,33 @@ def create(self): tokenizer_kwargs={"max_length": 1024}, model_kwargs={"torch_dtype": torch.float16, "trust_remote_code": True} ) - return None + self._instance = llm + return llm def apply(self): ... 
 
 # COMMAND ----------
 
+title_extractor = LlamaIndexTitleExtractor(nodes=5, llm_def = MPT7b())
+
+# COMMAND ----------
+
 doc_splitter = LlamaIndexDocSplitter(
     chunk_size = 1024,
     chunk_overlap = 32,
-    llm = MPT7b()
+    extractors = [title_extractor]
 )
 
 # COMMAND ----------
 
+# doc_splitter = LangChainRecursiveCharacterTextSplitter(
+#     chunk_size = 1024,
+#     chunk_overlap = 32
+# )
+
+# COMMAND ----------
+
 llm_feature = LLMFeature (
     name = "chunks",
     reader = LlamaIndexDocReader(),
@@ -73,7 +89,7 @@ def apply(self):
 
 # COMMAND ----------
 
-df = ff.assemble_llm_feature(spark, srcDirectory= "/dbfs/tmp/li_yu/va_llms/pdf", llmFeature=llm_feature, partitionNum=partition_num)
+df = ff.assemble_llm_feature(spark, srcDirectory= "your source document directory", llmFeature=llm_feature, partitionNum=partition_num)
 
 # COMMAND ----------
 
@@ -81,4 +97,9 @@ def apply(self):
 
 # COMMAND ----------
 
+df.write.mode("overwrite").saveAsTable("..")
+
+# COMMAND ----------
+
+
diff --git a/test/test_chunking.py b/test/test_chunking.py
index 258f54e..90dfce2 100644
--- a/test/test_chunking.py
+++ b/test/test_chunking.py
@@ -10,6 +10,8 @@
 from framework.feature_factory.catalog import LLMCatalogBase
 from enum import IntEnum
 from framework.feature_factory.llm_tools import *
+from llama_index.llms import HuggingFaceLLM
+import torch
 
 class TestLLMTools(unittest.TestCase):
 
@@ -19,6 +21,35 @@ def test_llamaindex_reader(self):
         doc_reader.create()
         docs = doc_reader.apply("test/data/sample.pdf")
         assert len(docs) == 2
+
+    def test_metadata_extractor(self):
+        class MPT7b(LLMTool):
+            def create(self):
+                generate_params = {
+                    "temperature": 1.0,
+                    "top_p": 1.0,
+                    "top_k": 50,
+                    "use_cache": True,
+                    "do_sample": True,
+                    "eos_token_id": 0,
+                    "pad_token_id": 0
+                }
+
+                self._instance = HuggingFaceLLM(
+                    max_new_tokens=256,
+                    generate_kwargs=generate_params,
+                    tokenizer_name="mosaicml/mpt-7b-instruct",
+                    model_name="mosaicml/mpt-7b-instruct",
+                    device_map="auto",
+                    tokenizer_kwargs={"max_length": 1024},
+                    model_kwargs={"torch_dtype": torch.float16, "trust_remote_code": True}
+                )
+                return None
+            def apply(self):
+                ...
+
+        title_extractor = LlamaIndexTitleExtractor(nodes=5, llm_def=MPT7b())
+        assert title_extractor.nodes == 5 and isinstance(title_extractor.llm_def, MPT7b)
 
     def test_llamaindex_splitter(self):
         doc_reader = LlamaIndexDocReader()
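
A minimal end-to-end sketch of the new extractor API introduced above, assuming llama-index 0.8.x (where MockLLM ships in llama_index.llms). StubLLMDef is a hypothetical stand-in for a real LLMTool definition such as the notebook's MPT7b; the other names are imported from this patch's llm_tools.py.

# Sketch only: composes the pieces from this diff with a stub LLM.
from framework.feature_factory.llm_tools import (
    LLMTool,
    LlamaIndexDocSplitter,
    LlamaIndexTitleExtractor,
)
from llama_index.llms import MockLLM  # assumption: available in llama-index 0.8.x


class StubLLMDef(LLMTool):
    """Hypothetical LLM definition; create() must populate self._instance."""

    def create(self):
        self._instance = MockLLM(max_tokens=256)
        return self._instance

    def apply(self):
        ...


# Wire the extractor list into the splitter, mirroring the notebook cells.
title_extractor = LlamaIndexTitleExtractor(nodes=5, llm_def=StubLLMDef())
doc_splitter = LlamaIndexDocSplitter(
    chunk_size=1024,
    chunk_overlap=32,
    extractors=[title_extractor],  # any list of LLMTool-based extractors
)
doc_splitter.create()  # builds the TokenTextSplitter and the MetadataExtractor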
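
Continuing the sketch above, a lighter-weight variant of the new test_metadata_extractor that exercises LlamaIndexTitleExtractor.create() without downloading mosaicml/mpt-7b-instruct; it assumes _require_init() returns True on first use, as LlamaIndexDocSplitter.create() already relies on.

# Exercise create() directly; get_instance() comes from the LLMTool base class.
title_extractor = LlamaIndexTitleExtractor(nodes=5, llm_def=StubLLMDef())
title_extractor.create()  # builds TitleExtractor(nodes=5, llm=MockLLM(...))
assert title_extractor.get_instance() is not None
assert title_extractor.nodes == 5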