Commit 1cb5bc0: Address PR feedback

kxtran committed Jun 12, 2024
1 parent 2bf4800 commit 1cb5bc0
Showing 2 changed files with 116 additions and 27 deletions.
77 changes: 53 additions & 24 deletions log10/_httpx_utils.py
@@ -267,20 +267,18 @@ def _init_log_row(request: Request):
     return log_row
 
 
-def check_provider_request(request: Request):
+def get_completion_id(request: Request):
     host = request.headers.get("host")
     if "anthropic" in host:
         paths = ["/v1/messages", "/v1/complete"]
         if not any(path in str(request.url) for path in paths):
-            logger.warning("Currently logging is only available for anthropic v1/messages and v1/complete.")
+            logger.debug("Currently logging is only available for anthropic v1/messages and v1/complete.")
             return
 
     if "openai" in host and "v1/chat/completions" not in str(request.url):
-        logger.warning("Currently logging is only available for openai v1/chat/completions.")
+        logger.debug("Currently logging is only available for openai v1/chat/completions.")
         return
 
-
-def get_completion_id(request: Request):
     completion_id = str(uuid.uuid4())
     request.headers["x-log10-completion-id"] = completion_id
     last_completion_response_var.set({"completionID": completion_id})
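The merged function now both validates the provider endpoint and stamps the outgoing request. Below is a minimal, self-contained sketch of that stamping pattern; the helper name tag_request and the local ContextVar are illustrative stand-ins for the module's internals, not library API:

    import uuid
    from contextvars import ContextVar

    import httpx

    # Mirrors the module-level variable used above, so in-process callers can
    # look up the id of the most recent request.
    last_completion_response_var: ContextVar[dict] = ContextVar("last_completion_response", default={})

    def tag_request(request: httpx.Request) -> str:
        # Generate a fresh id, expose it to the logging backend via a header,
        # and publish it through the ContextVar.
        completion_id = str(uuid.uuid4())
        request.headers["x-log10-completion-id"] = completion_id
        last_completion_response_var.set({"completionID": completion_id})
        return completion_id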
@@ -309,7 +307,13 @@ def patch_response(log_row: dict, llm_response: dict, request: Request):
if "v1/messages" in str(request.url):
llm_response = Anthropic.prepare_response(Message(**llm_response))
elif "v1/complete" in str(request.url):
llm_response = Anthropic.prepare_response(Completion(**llm_response))
prompt = ""
if request.content:
content = request.content.decode("utf-8")
content_json = json.loads(content)
prompt = content_json.get("prompt", "")

llm_response = Anthropic.prepare_response(Completion(**llm_response), input_prompt=prompt)
else:
logger.warning("Currently logging is only available for anthropic v1/messages and v1/complete.")

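The new branch recovers the original prompt from the raw request body so that v1/complete logs include it. A hedged illustration with a made-up payload of the shape that endpoint accepts:

    import json

    # Illustrative v1/complete request body (values are not from the commit).
    raw_body = b'{"model": "claude-1", "max_tokens_to_sample": 32, "prompt": "\\n\\nHuman: Hi\\n\\nAssistant:"}'

    content_json = json.loads(raw_body.decode("utf-8"))
    prompt = content_json.get("prompt", "")  # falls back to "" when no prompt field is present
    print(prompt)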
@@ -323,8 +327,16 @@ def patch_response(log_row: dict, llm_response: dict, request: Request):
     return log_row
 
 
-class _EventHookManager:
+class _RequestHooks:
+    """
+    The class to manage the event hooks for sync requests and initialize the log row.
+    The event hooks are:
+    - get_completion_id: to generate the completion id
+    - log_request: to send the sync request with initial log row to the log10 platform
+    """
+
     def __init__(self):
+        logger.debug("LOG10: initializing request hooks")
         self.event_hooks = {
             "request": [self.get_completion_id, self.log_request],
         }
@@ -333,7 +345,6 @@ def __init__(self):

     def get_completion_id(self, request: httpx.Request):
         logger.debug("LOG10: generating completion id")
-        check_provider_request(request)
         self.completion_id = get_completion_id(request)
 
     def log_request(self, request: httpx.Request):
@@ -342,9 +353,16 @@ def log_request(self, request: httpx.Request):
_try_post_request(url=f"{base_url}/api/completions/{self.completion_id}", payload=self.log_row)


class _AsyncEventHookManager:
class _AsyncRequestHooks:
"""
The class to manage the event hooks for async requests and initialize the log row.
The event hooks are:
- get_completion_id: to generate the completion id
- log_request: to send the sync request with initial log row to the log10 platform
"""

def __init__(self):
logger.debug("LOG10: initializing async event hook manager")
logger.debug("LOG10: initializing async request hooks")
self.event_hooks = {
"request": [self.get_completion_id, self.log_request],
}
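Both hook classes lean on httpx's event-hook mechanism: each callable registered under "request" is invoked with the outgoing httpx.Request just before it is sent. A tiny standalone sketch of that mechanism (the announce hook is illustrative, not part of this commit):

    import httpx

    def announce(request: httpx.Request) -> None:
        # Runs right before httpx sends the request, like the hooks above.
        print(f"about to send {request.method} {request.url}")

    client = httpx.Client(event_hooks={"request": [announce]})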
@@ -353,7 +371,6 @@ def __init__(self):

     async def get_completion_id(self, request: httpx.Request):
         logger.debug("LOG10: generating completion id")
-        check_provider_request(request)
         self.completion_id = get_completion_id(request)
 
     async def log_request(self, request: httpx.Request):
@@ -614,7 +631,7 @@ def parse_response_data(self, responses: list[str]):


 class _LogTransport(httpx.BaseTransport):
-    def __init__(self, transport: httpx.BaseTransport, event_hook_manager: _EventHookManager):
+    def __init__(self, transport: httpx.BaseTransport, event_hook_manager: _RequestHooks):
         self.transport = transport
         self.event_hook_manager = event_hook_manager
 
@@ -654,7 +671,7 @@ def handle_request(self, request: Request) -> Response:


 class _AsyncLogTransport(httpx.AsyncBaseTransport):
-    def __init__(self, transport: httpx.AsyncBaseTransport, event_hook_manager: _AsyncEventHookManager):
+    def __init__(self, transport: httpx.AsyncBaseTransport, event_hook_manager: _AsyncRequestHooks):
         self.transport = transport
         self.event_hook_manager = event_hook_manager
 
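Both transports follow httpx's custom-transport pattern: wrap an inner transport, delegate the actual send, and observe the response on the way back. A minimal sketch of the pattern (EchoTransport is illustrative, not part of this commit):

    import httpx

    class EchoTransport(httpx.BaseTransport):
        def __init__(self, inner: httpx.BaseTransport) -> None:
            self.inner = inner

        def handle_request(self, request: httpx.Request) -> httpx.Response:
            # Delegate to the wrapped transport, then inspect the result,
            # much as _LogTransport forwards the response on for logging.
            response = self.inner.handle_request(request)
            print(f"{request.url} -> {response.status_code}")
            return response

    client = httpx.Client(transport=EchoTransport(httpx.HTTPTransport()))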
@@ -696,19 +713,30 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response:


 class InitPatcher:
-    def __init__(self, module, async_class_name, sync_class_name=None):
+    def __init__(self, module, class_names: list[str]):
         logger.debug("LOG10: initializing patcher")
         self.module = module
-        self.sync_class_name = sync_class_name
-        self.async_class_name = async_class_name
-        self.origin_init = getattr(module, sync_class_name).__init__ if sync_class_name else None
-        self.async_origin_init = getattr(module, async_class_name).__init__
-        self.patch_init()
+        if len(class_names) > 2:
+            raise ValueError("Only two class names (sync and async) are allowed")
+
+        self.async_class_name = None
+        self.sync_class_name = None
+
+        for class_name in class_names:
+            if class_name.startswith("Async"):
+                self.async_class_name = class_name
+                self.async_origin_init = getattr(module, self.async_class_name).__init__
+            else:
+                self.sync_class_name = class_name
+                self.origin_init = getattr(module, self.sync_class_name).__init__
+
+        self._patch_init()
 
-    def patch_init(self):
+    def _patch_init(self):
         def new_init(instance, *args, **kwargs):
             logger.debug(f"LOG10: patching {self.sync_class_name}.__init__")
 
-            event_hook_manager = _EventHookManager()
+            event_hook_manager = _RequestHooks()
             httpx_client = httpx.Client(
                 event_hooks=event_hook_manager.event_hooks,
                 transport=_LogTransport(httpx.HTTPTransport(), event_hook_manager),
@@ -719,7 +747,7 @@ def new_init(instance, *args, **kwargs):
         def async_new_init(instance, *args, **kwargs):
             logger.debug(f"LOG10: patching {self.async_class_name}.__init__")
 
-            event_hook_manager = _AsyncEventHookManager()
+            event_hook_manager = _AsyncRequestHooks()
             async_httpx_client = httpx.AsyncClient(
                 event_hooks=event_hook_manager.event_hooks,
                 transport=_AsyncLogTransport(httpx.AsyncHTTPTransport(), event_hook_manager),
@@ -728,8 +756,9 @@ def async_new_init(instance, *args, **kwargs):
             self.async_origin_init(instance, *args, **kwargs)
 
         # Patch the asynchronous class __init__
-        async_class = getattr(self.module, self.async_class_name)
-        async_class.__init__ = async_new_init
+        if self.async_class_name:
+            async_class = getattr(self.module, self.async_class_name)
+            async_class.__init__ = async_new_init
 
         # Patch the synchronous class __init__ if provided
         if self.sync_class_name:
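The new list-based signature is exercised in load.py below: class names starting with "Async" are wired to the async client, all others to the sync one. For example:

    import anthropic
    import openai

    from log10._httpx_utils import InitPatcher

    InitPatcher(anthropic, ["AsyncAnthropic", "Anthropic"])  # patches sync and async clients
    InitPatcher(openai, ["AsyncOpenAI"])  # async-only patch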
66 changes: 63 additions & 3 deletions log10/load.py
@@ -745,7 +745,7 @@ def set_sync_log_text(USE_ASYNC=True):

 def log10(module, DEBUG_=False, USE_ASYNC_=True):
     """Intercept and overload module for logging purposes
-    support both openai V0 and V1, vertexai, and mistralai
+    support both openai V0 and V1, anthropic, vertexai, and mistralai
 
     Keyword arguments:
     module -- the module to be intercepted (e.g. openai)
@@ -793,6 +793,33 @@ def log10(module, DEBUG_=False, USE_ASYNC_=True):
     >>> ]
     >>> completion = llm.predict_messages(messages)
     >>> print(completion)
+
+    Example:
+    >>> from log10.load import log10
+    >>> import anthropic
+    >>> log10(anthropic)
+    >>> from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
+    >>> anthropic = Anthropic()
+    >>> completion = anthropic.completions.create(
+    >>>     model="claude-1",
+    >>>     max_tokens_to_sample=32,
+    >>>     prompt=f"{HUMAN_PROMPT} Hi, how are you? {AI_PROMPT}",
+    >>> )
+    >>> print(completion.completion)
+
+    Example:
+    >>> from log10.load import log10
+    >>> import anthropic
+    >>> from langchain.chat_models import ChatAnthropic
+    >>> from langchain.schema import HumanMessage, SystemMessage
+    >>> log10(anthropic)
+    >>> llm = ChatAnthropic(model="claude-1", temperature=0.7)
+    >>> messages = [
+    >>>     SystemMessage(content="You are a ping pong machine"),
+    >>>     HumanMessage(content="Ping?")
+    >>> ]
+    >>> completion = llm.predict_messages(messages)
+    >>> print(completion)
     """
     global DEBUG, USE_ASYNC, sync_log_text
     DEBUG = DEBUG_ or os.environ.get("LOG10_DEBUG", False)
@@ -827,7 +854,7 @@ def log10(module, DEBUG_=False, USE_ASYNC_=True):
         from log10._httpx_utils import InitPatcher
 
         # Patch the AsyncAnthropic and Anthropic class
-        InitPatcher(module, "AsyncAnthropic", "Anthropic")
+        InitPatcher(module, ["AsyncAnthropic", "Anthropic"])
     elif module.__name__ == "lamini":
         attr = module.api.utils.completion.Completion
         method = getattr(attr, "generate")
@@ -858,7 +885,7 @@ def log10(module, DEBUG_=False, USE_ASYNC_=True):
             from log10._httpx_utils import InitPatcher
 
             # Patch the AsyncOpenAI class
-            InitPatcher(module, "AsyncOpenAI")
+            InitPatcher(module, ["AsyncOpenAI"])
     else:
         attr = module.api_resources.completion.Completion
         method = getattr(attr, "create")
@@ -923,3 +950,36 @@ def __init__(self, *args, **kwargs):
         if not getattr(openai, "_log10_patched", False):
             log10(openai)
             openai._log10_patched = True
+
+
+try:
+    import anthropic
+except ImportError:
+    logger.warning("Anthropic not found. Skipping defining log10.load.Anthropic client.")
+else:
+    from anthropic import Anthropic
+
+    class Anthropic(Anthropic):
+        """
+        Example:
+        >>> from log10.load import Anthropic
+        >>> client = Anthropic(tags=["test", "load_anthropic"])
+        >>> message = client.messages.create(
+        ...     model="claude-3-haiku-20240307",
+        ...     max_tokens=100,
+        ...     temperature=0.9,
+        ...     system="Respond only in Yoda-speak.",
+        ...     messages=[{"role": "user", "content": "How are you today?"}],
+        ... )
+        >>> print(message.content[0].text)
+        """
+
+        def __init__(self, *args, **kwargs):
+            if "tags" in kwargs:
+                tags_var.set(kwargs.pop("tags"))
+
+            if not getattr(anthropic, "_log10_patched", False):
+                log10(anthropic)
+                anthropic._log10_patched = True
+
+            super().__init__(*args, **kwargs)

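Because log10(anthropic) also patches AsyncAnthropic, async callers get the same logging without the wrapper class. A hedged usage sketch (assumes ANTHROPIC_API_KEY and log10 credentials are set in the environment):

    import asyncio

    import anthropic
    from log10.load import log10

    log10(anthropic)  # patches both Anthropic and AsyncAnthropic

    async def main() -> None:
        client = anthropic.AsyncAnthropic()
        message = await client.messages.create(
            model="claude-3-haiku-20240307",
            max_tokens=100,
            messages=[{"role": "user", "content": "Ping?"}],
        )
        print(message.content[0].text)

    asyncio.run(main())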