diff --git a/demos/palm/python/docs-agent/README.md b/demos/palm/python/docs-agent/README.md
index 2a724efd2..806f3c217 100644
--- a/demos/palm/python/docs-agent/README.md
+++ b/demos/palm/python/docs-agent/README.md
@@ -24,7 +24,7 @@ and is required that you have access to Google’s [PaLM API][genai-doc-site].
 Keep in mind that this approach does not involve “fine-tuning” an LLM (large language model).
 Instead, the Docs Agent sample app uses a mixture of prompt engineering and embedding techniques,
 also known as Retrieval Augmented Generation (RAG), on top of a publicly available LLM model
-like PaLM 2.
+like PaLM 2.
 
 ![Docs Agent architecture](docs/images/docs-agent-architecture-01.png)
 
@@ -210,10 +210,10 @@ by the PaLM model:
 - Additional condition (for fact-checking):
 
   ```
-  Can you compare the text below to the context provided
-  in this prompt above and write a short message that warns the readers about
-  which part of the text they should consider fact-checking? (Please keep your
-  response concise and focus on only one important item.)"
+  Can you compare the text below to the information provided in this prompt above
+  and write a short message that warns the readers about which part of the text they
+  should consider fact-checking? (Please keep your response concise and focus on only
+  one important item.)"
   ```
 
 - Previously generated response
 
@@ -266,8 +266,7 @@ The following is the exact structure of this prompt:
 - Condition:
 
   ```
-  You are a helpful chatbot answering questions from users. Read the following context first
-  and answer the question at the end:
+  Read the context below and answer the question at the end:
   ```
 
 - Context:
 
@@ -578,8 +577,10 @@ To customize settings in the Docs Agent chat app, do the following:
    condition for your custom dataset, for example:
 
   ```
-   condition_text: "You are a helpful chatbot answering questions from developers working on
-   Flutter apps. Read the following context first and answer the question at the end:"
+   condition_text: "You are a helpful chatbot answering questions from **Flutter app developers**.
+   Read the context below first and answer the user's question at the end.
+   In your answer, provide a summary in three to five sentences. (BUT DO NOT USE
+   ANY INFORMATION YOU KNOW ABOUT THE WORLD.)"
   ```
 
 ### 2. Launch the Docs Agent chat app
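Note on the prompt structure the README changes above describe: the condition string is simply prepended to the retrieved context, and the user's question is appended at the end. A minimal standalone sketch of that assembly (simplified variable names, not the app's exact code):

```python
# Minimal sketch of how Docs Agent assembles a RAG prompt: the condition is
# prepended to the retrieved context, and the question is appended at the end.
condition = "Read the context below and answer the question at the end:"
context = "Docs Agent stores document embeddings in a Chroma vector database."
question = "Where does Docs Agent store its embeddings?"

prompt = f"{condition}\n\n{context}\n\nQuestion: {question}"
print(prompt)
```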
diff --git a/demos/palm/python/docs-agent/chatbot/chatui.py b/demos/palm/python/docs-agent/chatbot/chatui.py
index d6982cb7f..08b274762 100644
--- a/demos/palm/python/docs-agent/chatbot/chatui.py
+++ b/demos/palm/python/docs-agent/chatbot/chatui.py
@@ -25,6 +25,7 @@
     json,
 )
 import markdown
+import markdown.extensions.fenced_code
 from bs4 import BeautifulSoup
 import urllib
 import os
@@ -145,7 +146,9 @@ def ask_model(question):
     query_result = docs_agent.query_vector_store(question)
     context = query_result.fetch_formatted(Format.CONTEXT)
     context_with_instruction = docs_agent.add_instruction_to_context(context)
-    response = docs_agent.ask_text_model_with_context(context_with_instruction, question)
+    response = docs_agent.ask_text_model_with_context(
+        context_with_instruction, question
+    )
 
     ### PROMPT 2: FACT-CHECK THE PREVIOUS RESPONSE.
     fact_checked_response = docs_agent.ask_text_model_to_fact_check(
@@ -153,14 +156,21 @@ def ask_model(question):
     )
 
     ### PROMPT 3: GET 5 RELATED QUESTIONS.
-    # 1. Prepare a new question asking the model to come up with 5 related questions.
-    # 2. Ask the language model with the new question.
-    # 3. Parse the model's response into a list in HTML format.
+    # 1. Use the response from Prompt 1 as context and add a custom condition.
+    # 2. Prepare a new question asking the model to come up with 5 related questions.
+    # 3. Ask the language model with the new question.
+    # 4. Parse the model's response into a list in HTML format.
+    new_condition = "Read the context below and answer the user's question at the end."
+    new_context_with_instruction = docs_agent.add_custom_instruction_to_context(
+        new_condition, response
+    )
     new_question = (
         "What are 5 questions developers might ask after reading the context?"
     )
     new_response = markdown.markdown(
-        docs_agent.ask_text_model_with_context(response, new_question)
+        docs_agent.ask_text_model_with_context(
+            new_context_with_instruction, new_question
+        )
     )
     related_questions = parse_related_questions_response_to_html_list(new_response)
 
@@ -181,8 +191,8 @@ def ask_model(question):
     # - Convert the fact-check response from the model into HTML for rendering.
     # - A workaround to get the server's URL to work with the rewrite and like features.
     new_uuid = uuid.uuid1()
-    context_in_html = markdown.markdown(context)
-    response_in_html = markdown.markdown(response)
+    context_in_html = markdown.markdown(context, extensions=["fenced_code"])
+    response_in_html = markdown.markdown(response, extensions=["fenced_code"])
     fact_checked_response_in_html = markdown.markdown(fact_checked_response)
     server_url = request.url_root.replace("http", "https")
 
diff --git a/demos/palm/python/docs-agent/chatbot/static/css/style.css b/demos/palm/python/docs-agent/chatbot/static/css/style.css
index 066e442f0..e59e7f982 100644
--- a/demos/palm/python/docs-agent/chatbot/static/css/style.css
+++ b/demos/palm/python/docs-agent/chatbot/static/css/style.css
@@ -67,6 +67,11 @@ li {
   margin: 0 0 0.3em;
 }
 
+code {
+  font-family: math;
+  color: darkgreen;
+}
+
 /* ======= Style layout by ID ======= */
 
 #callout-box {
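The `fenced_code` changes in `chatui.py` affect how model responses are rendered. A quick standalone way to see the difference (only the `markdown` package is required; the resulting HTML is then subject to the new `code` CSS rule above):

```python
import markdown

text = 'Example:\n\n```\nprint("hello")\n```\n'
# Without the extension, the backtick fence is left as plain paragraph text.
print(markdown.markdown(text))
# With it, the fence is converted into a proper <pre><code> block.
print(markdown.markdown(text, extensions=["fenced_code"]))
```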
diff --git a/demos/palm/python/docs-agent/chroma.py b/demos/palm/python/docs-agent/chroma.py
index c6d3a46a4..61d5c753d 100644
--- a/demos/palm/python/docs-agent/chroma.py
+++ b/demos/palm/python/docs-agent/chroma.py
@@ -47,7 +47,7 @@ def __init__(self, chroma_dir) -> None:
     def list_collections(self):
         return self.client.list_collections()
 
-    def get_collection(self, name, embedding_function=None):
+    def get_collection(self, name, embedding_function=None, embedding_model=None):
         if embedding_function is not None:
             return ChromaCollection(
                 self.client.get_collection(name, embedding_function=embedding_function),
@@ -55,9 +55,17 @@ def get_collection(self, name, embedding_function=None):
             )
         # Read embedding meta information from the collection
         collection = self.client.get_collection(name, lambda x: None)
-        embedding_model = None
-        if collection.metadata:
+        if embedding_model is None and collection.metadata:
             embedding_model = collection.metadata.get("embedding_model", None)
+        if embedding_model is None:
+            # If embedding_model is not found in the metadata,
+            # use `models/embedding-gecko-001` by default.
+            logging.info(
+                "Embedding model is not specified in the metadata of "
+                "the collection %s. Using the default PaLM embedding model.",
+                name,
+            )
+            embedding_model = "models/embedding-gecko-001"
 
         if embedding_model == "local/all-mpnet-base-v2":
             base_dir = os.path.dirname(os.path.abspath(__file__))
@@ -67,24 +75,19 @@ def get_collection(self, name, embedding_function=None):
                     model_name=local_model_dir
                 )
             )
-        elif embedding_model is None or embedding_model == "palm/embedding-gecko-001":
-            if embedding_model is None:
-                logging.info(
-                    "Embedding model is not specified in the metadata of "
-                    "the collection %s. Using the default PaLM embedding model.",
-                    name,
-                )
-            palm = PaLM(embed_model="models/embedding-gecko-001", find_models=False)
-            # We can not redefine embedding_function with def and
-            # have to assign a lambda to it
-            # pylint: disable-next=unnecessary-lambda-assignment
-            embedding_function = lambda texts: [palm.embed(text) for text in texts]
         else:
-            raise ChromaEmbeddingModelNotSupportedError(
-                f"Embedding model {embedding_model} specified by collection {name} "
-                "is not supported."
-            )
+            print("Embedding model: " + str(embedding_model))
+            try:
+                palm = PaLM(embed_model=embedding_model, find_models=False)
+                # We cannot redefine embedding_function with def and
+                # have to assign a lambda to it
+                # pylint: disable-next=unnecessary-lambda-assignment
+                embedding_function = lambda texts: [palm.embed(text) for text in texts]
+            except Exception:
+                raise ChromaEmbeddingModelNotSupportedError(
+                    f"Embedding model {embedding_model} specified by collection {name} "
+                    "is not supported."
+                )
 
         return ChromaCollection(
             self.client.get_collection(name, embedding_function=embedding_function),
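A usage sketch for the extended `get_collection()` signature above (the collection name and path are this demo's defaults; illustrative only, not part of the change):

```python
from chroma import Chroma

chroma = Chroma("vector_stores/chroma")
# Pass the embedding model explicitly; when omitted, get_collection() now
# falls back to the collection metadata, then to "models/embedding-gecko-001".
collection = chroma.get_collection(
    "docs_collection", embedding_model="models/embedding-gecko-001"
)
```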
diff --git a/demos/palm/python/docs-agent/config.yaml b/demos/palm/python/docs-agent/config.yaml
index 44e5fb7e0..8435b17cc 100644
--- a/demos/palm/python/docs-agent/config.yaml
+++ b/demos/palm/python/docs-agent/config.yaml
@@ -16,6 +16,16 @@
 
 ### Configuration for Docs Agent ###
 
+### PaLM environment
+#
+# api_endpoint: The PaLM API endpoint used by Docs Agent.
+#
+# embedding_model: The PaLM embedding model used to generate embeddings.
+#
+api_endpoint: "generativelanguage.googleapis.com"
+embedding_model: "models/embedding-gecko-001"
+
+
 ### Docs Agent environment
 #
 # product_name: The name of your product to appear on the chatbot UI.
@@ -31,10 +41,14 @@
 # collection_name: The name used to identify a dataset collection by
 # the Chroma vector database.
 #
+# log_level: The verbosity level of logs printed on the terminal
+# by the chatbot app: NORMAL or VERBOSE
+#
 product_name: "My product"
 output_path: "data/plain_docs"
 vector_db_dir: "vector_stores/chroma"
 collection_name: "docs_collection"
+log_level: "NORMAL"
 
 
 ### Documentation sources
@@ -70,14 +84,17 @@ input:
 # model_error_message: The error message returned to the user when language
 # models are unable to provide responses.
 #
-condition_text: "You are a helpful chatbot answering questions from users. Read
-the following context first and answer the question at the end:"
+condition_text: "You are a helpful chatbot answering questions from users.
+Read the context below first and answer the user's question at the end.
+In your answer, provide a summary in three to five sentences. (BUT DO NOT USE
+ANY INFORMATION YOU KNOW ABOUT THE WORLD.)"
 
-fact_check_question: "Can you compare the text below to the context provided
-in this prompt above and write a short message that warns the readers about
-which part of the text they should consider fact-checking? (Please keep your
-response concise and focus on only one important item.)"
+fact_check_question: "Can you compare the text below to the information
+provided in this prompt above and write a short message that warns the readers
+about which part of the text they should consider fact-checking? (Please keep
+your response concise, focus on only one important item, but DO NOT USE BOLD
+TEXT IN YOUR RESPONSE.)"
 
-model_error_message: "PaLM is not able to answer this question at the
-moment. Rephrase the question and try asking again."
+model_error_message: "PaLM is not able to answer this question at the moment.
+Rephrase the question and try asking again."
diff --git a/demos/palm/python/docs-agent/docs/images/docs-agent-benchmarks-01.png b/demos/palm/python/docs-agent/docs/images/docs-agent-benchmarks-01.png
new file mode 100644
index 000000000..2907ac307
Binary files /dev/null and b/demos/palm/python/docs-agent/docs/images/docs-agent-benchmarks-01.png differ
diff --git a/demos/palm/python/docs-agent/docs/images/docs-agent-pre-processing-01.png b/demos/palm/python/docs-agent/docs/images/docs-agent-pre-processing-01.png
index 7e81cccdf..98b34c02d 100644
Binary files a/demos/palm/python/docs-agent/docs/images/docs-agent-pre-processing-01.png and b/demos/palm/python/docs-agent/docs/images/docs-agent-pre-processing-01.png differ
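The new `log_level` key is consumed by `docs_agent.py` below. The intended behavior is a simple gate; a sketch under that assumption (`maybe_print_prompt` is a hypothetical helper, not part of the change):

```python
LOG_LEVEL = "VERBOSE"  # value read from config.yaml; "NORMAL" stays quiet

def maybe_print_prompt(prompt: str) -> None:
    # Hypothetical helper: with VERBOSE, dump each assembled prompt for
    # debugging; with NORMAL, print nothing extra.
    if LOG_LEVEL == "VERBOSE":
        print(prompt)

maybe_print_prompt("Condition\n\nContext\n\nQuestion: ...")
```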
diff --git a/demos/palm/python/docs-agent/docs_agent.py b/demos/palm/python/docs-agent/docs_agent.py
index 7f5f7b155..4f7c65694 100644
--- a/demos/palm/python/docs-agent/docs_agent.py
+++ b/demos/palm/python/docs-agent/docs_agent.py
@@ -34,15 +34,16 @@
 
 # Select your PaLM API endpoint.
 PALM_API_ENDPOINT = "generativelanguage.googleapis.com"
-
-palm = PaLM(api_key=API_KEY, api_endpoint=PALM_API_ENDPOINT)
-
-BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+EMBEDDING_MODEL = None
 
 # Set up the path to the chroma vector database.
+BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 LOCAL_VECTOR_DB_DIR = os.path.join(BASE_DIR, "vector_stores/chroma")
 COLLECTION_NAME = "docs_collection"
 
+# Set the log level for the DocsAgent class: NORMAL or VERBOSE
+LOG_LEVEL = "NORMAL"
+
 IS_CONFIG_FILE = True
 if IS_CONFIG_FILE:
     config_values = read_config.ReadConfig()
@@ -51,10 +52,16 @@
     CONDITION_TEXT = config_values.returnConfigValue("condition_text")
     FACT_CHECK_QUESTION = config_values.returnConfigValue("fact_check_question")
     MODEL_ERROR_MESSAGE = config_values.returnConfigValue("model_error_message")
+    LOG_LEVEL = config_values.returnConfigValue("log_level")
+    PALM_API_ENDPOINT = config_values.returnConfigValue("api_endpoint")
+    EMBEDDING_MODEL = config_values.returnConfigValue("embedding_model")
 
 # Select the number of contents to be used for providing context.
 NUM_RETURNS = 5
 
+# Initialize the PaLM instance.
+palm = PaLM(api_key=API_KEY, api_endpoint=PALM_API_ENDPOINT)
+
 
 class DocsAgent:
     """DocsAgent class"""
@@ -65,7 +72,9 @@ def __init__(self):
             "Using the local vector database created at %s", LOCAL_VECTOR_DB_DIR
         )
         self.chroma = Chroma(LOCAL_VECTOR_DB_DIR)
-        self.collection = self.chroma.get_collection(COLLECTION_NAME)
+        self.collection = self.chroma.get_collection(
+            COLLECTION_NAME, embedding_model=EMBEDDING_MODEL
+        )
         # Update PaLM's custom prompt strings
         self.prompt_condition = CONDITION_TEXT
         self.fact_check_question = FACT_CHECK_QUESTION
@@ -74,6 +83,9 @@ def __init__(self):
     # Use this method for talking to PaLM (Text)
     def ask_text_model_with_context(self, context, question):
         new_prompt = f"{context}\n\nQuestion: {question}"
+        # Print the prompt for debugging if the log level is VERBOSE.
+        if LOG_LEVEL == "VERBOSE":
+            self.print_the_prompt(new_prompt)
         try:
             response = palm.generate_text(
                 prompt=new_prompt,
@@ -119,3 +131,24 @@ def add_instruction_to_context(self, context):
         new_context = ""
         new_context += self.prompt_condition + "\n\n" + context
         return new_context
+
+    # Add custom instruction as a prefix to the context
+    def add_custom_instruction_to_context(self, condition, context):
+        new_context = ""
+        new_context += condition + "\n\n" + context
+        return new_context
+
+    # Generate an embedding given text input
+    def generate_embedding(self, text):
+        return palm.embed(text)
+
+    # Print the prompt on the terminal for debugging
+    def print_the_prompt(self, prompt):
+        print("#########################################")
+        print("#                PROMPT                 #")
+        print("#########################################")
+        print(prompt)
+        print("#########################################")
+        print("#            END OF PROMPT              #")
+        print("#########################################")
+        print("\n")
diff --git a/demos/palm/python/docs-agent/pyproject.toml b/demos/palm/python/docs-agent/pyproject.toml
index 7b01146c8..42da764e5 100644
--- a/demos/palm/python/docs-agent/pyproject.toml
+++ b/demos/palm/python/docs-agent/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "docs-agent"
-version = "0.1.5"
+version = "0.1.6"
 description = ""
 authors = ["Docs Agent contributors"]
 readme = "README.md"
diff --git a/demos/palm/python/docs-agent/scripts/markdown_to_plain_text.py b/demos/palm/python/docs-agent/scripts/markdown_to_plain_text.py
index 7b4f3721f..2121026ef 100644
--- a/demos/palm/python/docs-agent/scripts/markdown_to_plain_text.py
+++ b/demos/palm/python/docs-agent/scripts/markdown_to_plain_text.py
@@ -192,7 +192,7 @@ def process_page_and_section_titles(markdown_text):
             new_line = (
                 '# The "'
                 + page_title
-                + '" page contains the following content:\n\n'
+                + '" page includes the following information:\n'
             )
 
         if section_title:
@@ -201,7 +201,7 @@ def process_page_and_section_titles(markdown_text):
                 + page_title
                 + '" page has the "'
                 + section_title
-                + '" section that contains the following content:\n'
+                + '" section that includes the following information:\n'
             )
 
         if subsection_title:
@@ -212,7 +212,7 @@ def process_page_and_section_titles(markdown_text):
                 + section_title
                 + '" section has the "'
                 + subsection_title
-                + '" subsection that contains the following content:\n'
+                + '" subsection that includes the following information:\n'
             )
 
         if skip_this_line is False:
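To make the `process_page_and_section_titles()` wording change above concrete, this is the kind of heading the script now emits (sample titles made up):

```python
page_title = "Docs Agent concepts"
section_title = "Processing documents"

print('# The "' + page_title + '" page includes the following information:')
print(
    '# The "' + page_title + '" page has the "' + section_title
    + '" section that includes the following information:'
)
```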
diff --git a/demos/palm/python/docs-agent/scripts/populate_vector_database.py b/demos/palm/python/docs-agent/scripts/populate_vector_database.py
index e6a0c5090..27ee767aa 100644
--- a/demos/palm/python/docs-agent/scripts/populate_vector_database.py
+++ b/demos/palm/python/docs-agent/scripts/populate_vector_database.py
@@ -52,6 +52,8 @@
 ### Set up the path to the local LLM ###
 LOCAL_VECTOR_DB_DIR = os.path.join(BASE_DIR, "vector_stores/chroma")
 COLLECTION_NAME = "docs_collection"
+PALM_API_ENDPOINT = "generativelanguage.googleapis.com"
+EMBEDDING_MODEL = None
 
 IS_CONFIG_FILE = True
 if IS_CONFIG_FILE:
@@ -60,6 +62,8 @@
     input_len = config_values.returnInputCount()
     LOCAL_VECTOR_DB_DIR = config_values.returnConfigValue("vector_db_dir")
     COLLECTION_NAME = config_values.returnConfigValue("collection_name")
+    PALM_API_ENDPOINT = config_values.returnConfigValue("api_endpoint")
+    EMBEDDING_MODEL = config_values.returnConfigValue("embedding_model")
 
 ### Select the file index that is generated with your plain text files, same directory
 INPUT_FILE_INDEX = "file_index.json"
@@ -94,13 +98,24 @@
     msg = "The file " + FULL_INDEX_PATH + " does not exist."
 
 if EMBEDDINGS_TYPE == "PALM":
-    palm.configure(api_key=API_KEY)
-    # This returns models/embedding-gecko-001"
+    palm.configure(api_key=API_KEY, client_options={"api_endpoint": PALM_API_ENDPOINT})
+    # Scan the list of PaLM models.
     models = [
         m for m in palm.list_models() if "embedText" in m.supported_generation_methods
     ]
-    # MODEL = "models/embedding-gecko-001"
-    MODEL = models[0]
+    if EMBEDDING_MODEL is not None:
+        # If `embedding_model` is specified in the `config.yaml` file, select that model.
+        found_model = False
+        for m in models:
+            if m.name == EMBEDDING_MODEL:
+                MODEL = m
+                print("[INFO] Embedding model is set to " + str(m.name) + "\n")
+                found_model = True
+        if found_model is False:
+            sys.exit("[ERROR] Cannot find the embedding model: " + str(EMBEDDING_MODEL))
+    else:
+        # By default, pick the first model on the list (likely "models/embedding-gecko-001")
+        MODEL = models[0]
 elif EMBEDDINGS_TYPE == "LOCAL":
     MODEL = os.path.join(BASE_DIR, "models/all-mpnet-base-v2")
     emb_fn = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=MODEL)
@@ -110,6 +125,7 @@
 
 chroma_client = chromadb.PersistentClient(path=LOCAL_VECTOR_DB_DIR)
 
+
 # Create embed function for PaLM
 # API call limit to 5 qps
 @sleep_and_retry
@@ -220,7 +236,7 @@ def embed_function(texts: Documents) -> Embeddings:
     match3 = re.search(r"(.*)\.md$", url)
     url = match3[1]
     # Replaces the URL if it comes from frontmatter
-    if (final_url):
+    if final_url:
         url = final_url_value
     # Creates a dictionary with basic metadata values
     # (i.e. source, URL, and md_hash)
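For reference, the model scan in `populate_vector_database.py` uses the public PaLM API surface. A standalone sketch of the same lookup (requires your own API key; the model name is this demo's default):

```python
import google.generativeai as palm

palm.configure(api_key="YOUR_API_KEY")  # placeholder key
EMBEDDING_MODEL = "models/embedding-gecko-001"

# Keep only models that support the embedText method, as the script does.
models = [
    m for m in palm.list_models()
    if "embedText" in m.supported_generation_methods
]
match = next((m for m in models if m.name == EMBEDDING_MODEL), None)
print(match.name if match else "No matching embedding model found.")
```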