Skip to content

Commit

Permalink
Merge pull request #40 from jepler/improve-key-handling
Browse files Browse the repository at this point in the history
  • Loading branch information
jepler authored Oct 23, 2024
2 parents 8f126d3 + 4fc3607 commit 993b178
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 35 deletions.
9 changes: 3 additions & 6 deletions src/chap/backends/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,17 @@
import httpx

from ..core import AutoAskMixin, Backend
from ..key import get_key
from ..key import UsesKeyMixin
from ..session import Assistant, Role, Session, User


class Anthropic(AutoAskMixin):
class Anthropic(AutoAskMixin, UsesKeyMixin):
@dataclass
class Parameters:
url: str = "https://api.anthropic.com"
model: str = "claude-3-5-sonnet-20240620"
max_new_tokens: int = 1000
api_key_name = "anthropic_api_key"

def __init__(self) -> None:
super().__init__()
Expand Down Expand Up @@ -88,10 +89,6 @@ async def aask(

session.extend([User(query), Assistant("".join(new_content))])

@classmethod
def get_key(cls) -> str:
return get_key("anthropic_api_key")


def factory() -> Backend:
"""Uses the anthropic text-generation-interface web API"""
Expand Down
9 changes: 3 additions & 6 deletions src/chap/backends/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
import httpx

from ..core import AutoAskMixin, Backend
from ..key import get_key
from ..key import UsesKeyMixin
from ..session import Assistant, Role, Session, User


class HuggingFace(AutoAskMixin):
class HuggingFace(AutoAskMixin, UsesKeyMixin):
@dataclass
class Parameters:
url: str = "https://api-inference.huggingface.co"
Expand All @@ -24,6 +24,7 @@ class Parameters:
after_user: str = """ [/INST] """
after_assistant: str = """ </s><s>[INST] """
stop_token_id = 2
api_key_name = "huggingface_api_token"

def __init__(self) -> None:
super().__init__()
Expand Down Expand Up @@ -110,10 +111,6 @@ async def aask(

session.extend([User(query), Assistant("".join(new_content))])

@classmethod
def get_key(cls) -> str:
return get_key("huggingface_api_token")


def factory() -> Backend:
"""Uses the huggingface text-generation-interface web API"""
Expand Down
9 changes: 3 additions & 6 deletions src/chap/backends/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,17 @@
import httpx

from ..core import AutoAskMixin
from ..key import get_key
from ..key import UsesKeyMixin
from ..session import Assistant, Session, User


class Mistral(AutoAskMixin):
class Mistral(AutoAskMixin, UsesKeyMixin):
@dataclass
class Parameters:
url: str = "https://api.mistral.ai"
model: str = "open-mistral-7b"
max_new_tokens: int = 1000
api_key_name = "mistral_api_key"

def __init__(self) -> None:
super().__init__()
Expand Down Expand Up @@ -91,9 +92,5 @@ async def aask(

session.extend([User(query), Assistant("".join(new_content))])

@classmethod
def get_key(cls) -> str:
return get_key("mistral_api_key")


factory = Mistral
13 changes: 7 additions & 6 deletions src/chap/backends/openai_chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import tiktoken

from ..core import Backend
from ..key import get_key
from ..key import UsesKeyMixin
from ..session import Assistant, Message, Session, User, session_to_list


Expand Down Expand Up @@ -63,7 +63,7 @@ def from_model(cls, model: str) -> "EncodingMeta":
return cls(encoding, tokens_per_message, tokens_per_name, tokens_overhead)


class ChatGPT:
class ChatGPT(UsesKeyMixin):
@dataclass
class Parameters:
model: str = "gpt-4o-mini"
Expand All @@ -81,6 +81,11 @@ class Parameters:
top_p: float | None = None
"""The model temperature for sampling"""

api_key_name: str = "openai_api_key"
"""The OpenAI API key"""

parameters: Parameters

def __init__(self) -> None:
self.parameters = self.Parameters()

Expand Down Expand Up @@ -171,10 +176,6 @@ async def aask(

session.extend([User(query), Assistant("".join(new_content))])

@classmethod
def get_key(cls) -> str:
return get_key("openai_api_key")


def factory() -> Backend:
"""Uses the OpenAI chat completion API"""
Expand Down
58 changes: 47 additions & 11 deletions src/chap/key.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,61 @@
#
# SPDX-License-Identifier: MIT

import json
import subprocess
from typing import Protocol
import functools

import platformdirs


class APIKeyProtocol(Protocol):
@property
def api_key_name(self) -> str:
...


class HasKeyProtocol(Protocol):
@property
def parameters(self) -> APIKeyProtocol:
...


class UsesKeyMixin:
def get_key(self: HasKeyProtocol) -> str:
return get_key(self.parameters.api_key_name)


class NoKeyAvailable(Exception):
pass


_key_path_base = platformdirs.user_config_path("chap")


@functools.cache
def get_key(name: str, what: str = "openai api key") -> str:
key_path = _key_path_base / name
if not key_path.exists():
raise NoKeyAvailable(
f"Place your {what} in {key_path} and run the program again"
)

with open(key_path, encoding="utf-8") as f:
return f.read().strip()
USE_PASSWORD_STORE = _key_path_base / "USE_PASSWORD_STORE"

if USE_PASSWORD_STORE.exists():
content = USE_PASSWORD_STORE.read_text(encoding="utf-8")
if content.strip():
cfg = json.loads(content)
pass_command: list[str] = cfg.get("PASS_COMMAND", ["pass", "show"])
pass_prefix: str = cfg.get("PASS_PREFIX", "chap/")

@functools.cache
def get_key(name: str, what: str = "api key") -> str:
key_path = f"{pass_prefix}{name}"
command = pass_command + [key_path]
return subprocess.check_output(command, encoding="utf-8").split("\n")[0]

else:

@functools.cache
def get_key(name: str, what: str = "api key") -> str:
key_path = _key_path_base / name
if not key_path.exists():
raise NoKeyAvailable(
f"Place your {what} in {key_path} and run the program again"
)

with open(key_path, encoding="utf-8") as f:
return f.read().strip()

0 comments on commit 993b178

Please sign in to comment.