
Commit

add capacity for self-hosted llama model (#57)
* add capacity for self-hosted llama model

* fix linting errors
MichaelClifford authored Oct 10, 2023
1 parent 89201dd commit d52c098
Showing 3 changed files with 83 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/.env.config.example
@@ -1,7 +1,11 @@
MODEL_PROVIDER=vertex
#MODEL_PROVIDER=openai
#MODEL_PROVIDER=hosted
MODEL_TEMPERATURE=0.0

### Hosted Model ###
HOSTED_MODEL_URI="http://somehosted-model-uri"

### Vertex AI ###
VERTEX_PROJECT_ID=shadowbot-YOURNAME
VERTEX_REGION=us-central1
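To route requests through the new provider, the example values above would be switched over; a minimal sketch (the URI below is a placeholder for an actual model server, not a value from this commit):

# Placeholder values; point HOSTED_MODEL_URI at your running model server
MODEL_PROVIDER=hosted
HOSTED_MODEL_URI="http://my-llama2-server:8080"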
53 changes: 53 additions & 0 deletions src/hosted_llm.py
@@ -0,0 +1,53 @@
"""Module for handling self-hosted LLama2 models"""

from typing import Any, List, Mapping, Optional
import requests
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.schema.output_parser import BaseOutputParser


class HostedLLM(LLM):
    """
    Class to define interaction with the hosted LLM at a specified URI
    """
    uri: str

    @property
    def _llm_type(self) -> str:
        return "custom"

    def _call(self,
              prompt: str,
              stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs: Any,
              ) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        # Send the prompt to the hosted model server as a query parameter.
        response = requests.get(self.uri, params={"text": prompt}, timeout=600)
        if response.status_code == 200:
            # Stringify the raw bytes; CustomLlamaParser strips the echoed prompt.
            return str(response.content)
        return f"Model Server is not Working due to error {response.status_code}"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"uri": self.uri}


class CustomLlamaParser(BaseOutputParser[str]):  # pylint: disable=R0903
    """Class to correctly parse model outputs"""

    def parse(self, text: str) -> str:
        """Parse the output of our LLM"""
        # Error messages from the model server pass straight through.
        if text.startswith("Model Server is not Working due"):
            return text
        # The server echoes the full prompt; keep only the text after "[/INST]".
        cleaned = str(text).split("[/INST]")
        return cleaned[1]

    @property
    def _type(self) -> str:
        return "custom_output_parser"

27 changes: 26 additions & 1 deletion src/v1/endpoints.py
@@ -3,13 +3,14 @@
from fastapi import APIRouter, Body, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel # pylint: disable=E0611

from langchain.llms import OpenAI
from langchain.llms import VertexAI
from langchain.callbacks import get_openai_callback
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

from src.hosted_llm import HostedLLM
from src.hosted_llm import CustomLlamaParser
from src.config import Config
from src.embeddings import EmbeddingSource
from src.logging_setup import setup_logger
@@ -42,10 +43,31 @@ def call_language_model(input_val):
        result = call_openai(input_val, prompt)
    elif model_provider == 'vertex':
        result = call_vertexai(input_val, prompt)
    elif model_provider == 'hosted':
        result = call_hosted_llm(input_val, prompt)
    else:
        raise ValueError(f"Invalid model name: {model_provider}")
    return result


def call_hosted_llm(input_val, prompt):
    """Call the hosted language model and return the result.

    Args:
        input_val: The input value to pass to the language model.
        prompt: The prompt template to wrap around the input value.

    Returns:
        The result from the language model.
    """
    hosted_model_name = config.get("HOSTED_MODEL_NAME", "Llama2-Hosted")
    logger.debug("Using self-hosted model: %s", hosted_model_name)
    hosted_model_uri = config.get("HOSTED_MODEL_URI", None)
    llm = HostedLLM(uri=hosted_model_uri)
    chain = LLMChain(llm=llm, prompt=prompt, output_parser=CustomLlamaParser())
    result = chain.run(input_val)
    return result


def call_vertexai(input_val, prompt):
"""Call the Vertex AI language model and return the result.
@@ -256,6 +278,7 @@ def synthesize_response(

    if prompt is None:
        prompt = (
            "<s>[INST] <<SYS>> \n"
            "Below is the only information you know.\n"
            "It was obtained by doing a vector search for the user's query:\n\n"
            "---START INFO---\n\n{embedding_results}\n\n"
@@ -266,6 +289,8 @@
"Use no other knowledge to respond. Do not make anything up. "
"You can let the reader know if do not think you have enough information "
"to respond to their query...\n\n"
"<</SYS>>"
f"{query} [/INST]"
)

prompt = prompt.format(embedding_results=embedding_results_text)
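The added tags follow the Llama 2 chat convention: the system instructions sit between "<<SYS>>" and "<</SYS>>", and the whole turn is wrapped in "<s>[INST] ... [/INST]". A rough sketch of what the finished string looks like for one query (contents abbreviated, query and info are hypothetical):

# Hypothetical rendering of the default prompt after formatting.
example_prompt = (
    "<s>[INST] <<SYS>> \n"
    "...system instructions and the vector-search results go here...\n"
    "<</SYS>>"
    "How do I deploy the bot? [/INST]"
)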
