Merge branch 'main' into llm_tools
srdas authored Sep 16, 2024
2 parents 92f0b29 + 7b4a708 commit f7be16e
Showing 24 changed files with 119 additions and 63 deletions.
12 changes: 6 additions & 6 deletions .github/workflows/e2e-tests.yml
@@ -2,13 +2,13 @@ name: E2E Tests

# suppress warning raised by https://github.com/jupyter/jupyter_core/pull/292
env:
JUPYTER_PLATFORM_DIRS: '1'
JUPYTER_PLATFORM_DIRS: "1"

on:
push:
branches: main
pull_request:
branches: '*'
branches: "*"

jobs:
e2e-tests:
@@ -41,17 +41,17 @@ jobs:
${{ github.workspace }}/pw-browsers
key: ${{ runner.os }}-${{ hashFiles('packages/jupyter-ai/ui-tests/yarn.lock') }}

- name: Install browser
- name: Install Chromium
working-directory: packages/jupyter-ai/ui-tests
run: jlpm install-chromium

- name: Execute e2e tests
- name: Run E2E tests
working-directory: packages/jupyter-ai/ui-tests
run: jlpm test

- name: Upload Playwright Test report
- name: Upload Playwright test report
if: always()
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v4
with:
name: jupyter-ai-playwright-tests-linux
path: |
23 changes: 21 additions & 2 deletions .github/workflows/unit-tests.yml
@@ -1,4 +1,4 @@
name: Python Unit Tests
name: Python Tests

# suppress warning raised by https://github.com/jupyter/jupyter_core/pull/292
env:
@@ -12,7 +12,7 @@ on:

jobs:
unit-tests:
name: Linux
name: Unit tests
runs-on: ubuntu-latest
steps:
- name: Checkout
@@ -28,3 +28,22 @@ jobs:
run: |
set -eux
pytest -vv -r ap --cov jupyter_ai
typing-tests:
name: Typing test
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4

- name: Base Setup
uses: jupyterlab/maintainer-tools/.github/actions/base-setup@v1

- name: Install extension dependencies and build the extension
run: ./scripts/install.sh

- name: Run mypy
run: |
set -eux
mypy --version
mypy packages/jupyter-ai
25 changes: 23 additions & 2 deletions CHANGELOG.md
@@ -2,6 +2,29 @@

<!-- <START NEW CHANGELOG ENTRY> -->

## 2.23.0

([Full Changelog](https://github.com/jupyterlab/jupyter-ai/compare/@jupyter-ai/core@2.22.0...83cbd8ea240f1429766c417bada3bfb39afc4462))

### Enhancements made

- Allow unlimited LLM memory through traitlets configuration [#986](https://github.com/jupyterlab/jupyter-ai/pull/986) ([@krassowski](https://github.com/krassowski))
- Allow to disable automatic inline completions [#981](https://github.com/jupyterlab/jupyter-ai/pull/981) ([@krassowski](https://github.com/krassowski))
- Add ability to delete messages + start new chat session [#951](https://github.com/jupyterlab/jupyter-ai/pull/951) ([@michaelchia](https://github.com/michaelchia))

### Bugs fixed

- Fix `RunnableWithMessageHistory` import [#980](https://github.com/jupyterlab/jupyter-ai/pull/980) ([@krassowski](https://github.com/krassowski))
- Fix sort messages [#975](https://github.com/jupyterlab/jupyter-ai/pull/975) ([@michaelchia](https://github.com/michaelchia))

### Contributors to this release

([GitHub contributors page for this release](https://github.com/jupyterlab/jupyter-ai/graphs/contributors?from=2024-08-29&to=2024-09-11&type=c))

[@dlqqq](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Adlqqq+updated%3A2024-08-29..2024-09-11&type=Issues) | [@krassowski](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Akrassowski+updated%3A2024-08-29..2024-09-11&type=Issues) | [@michaelchia](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Amichaelchia+updated%3A2024-08-29..2024-09-11&type=Issues) | [@srdas](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Asrdas+updated%3A2024-08-29..2024-09-11&type=Issues)

<!-- <END NEW CHANGELOG ENTRY> -->

## 2.22.0

([Full Changelog](https://github.com/jupyterlab/jupyter-ai/compare/@jupyter-ai/core@2.21.0...79158abf7044605c5776e205dd171fe87fb64142))
@@ -25,8 +48,6 @@

[@dlqqq](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Adlqqq+updated%3A2024-08-19..2024-08-29&type=Issues) | [@krassowski](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Akrassowski+updated%3A2024-08-19..2024-08-29&type=Issues) | [@michaelchia](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Amichaelchia+updated%3A2024-08-19..2024-08-29&type=Issues) | [@pre-commit-ci](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Apre-commit-ci+updated%3A2024-08-19..2024-08-29&type=Issues) | [@srdas](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Asrdas+updated%3A2024-08-19..2024-08-29&type=Issues) | [@trducng](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Atrducng+updated%3A2024-08-19..2024-08-29&type=Issues)

<!-- <END NEW CHANGELOG ENTRY> -->

## 2.21.0

([Full Changelog](https://github.com/jupyterlab/jupyter-ai/compare/@jupyter-ai/core@2.20.0...83e368b9d04904f9eb0ad4b1f0759bf3b7bbc93d))
2 changes: 1 addition & 1 deletion lerna.json
@@ -1,7 +1,7 @@
{
"$schema": "node_modules/lerna/schemas/lerna-schema.json",
"useWorkspaces": true,
"version": "2.22.0",
"version": "2.23.0",
"npmClient": "yarn",
"useNx": true
}
2 changes: 1 addition & 1 deletion package.json
@@ -1,6 +1,6 @@
{
"name": "@jupyter-ai/monorepo",
"version": "2.22.0",
"version": "2.23.0",
"description": "A generative AI extension for JupyterLab",
"private": true,
"keywords": [
@@ -43,6 +43,7 @@ class InlineCompletionItem(BaseModel):

class CompletionError(BaseModel):
type: str
title: str
traceback: str


Empty file.
2 changes: 1 addition & 1 deletion packages/jupyter-ai-magics/package.json
@@ -1,6 +1,6 @@
{
"name": "@jupyter-ai/magics",
"version": "2.22.0",
"version": "2.23.0",
"description": "Jupyter AI magics Python package. Not published on NPM.",
"private": true,
"homepage": "https://github.com/jupyterlab/jupyter-ai",
2 changes: 1 addition & 1 deletion packages/jupyter-ai-test/package.json
@@ -1,6 +1,6 @@
{
"name": "@jupyter-ai/test",
"version": "2.22.0",
"version": "2.23.0",
"description": "Jupyter AI test package. Not published on NPM or PyPI.",
"private": true,
"homepage": "https://github.com/jupyterlab/jupyter-ai",
1 change: 1 addition & 0 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
@@ -72,6 +72,7 @@ async def process_message(self, message: HumanChatMessage):

try:
with self.pending("Searching learned documents", message):
assert self.llm_chain
result = await self.llm_chain.acall({"question": query})
response = result["answer"]
self.reply(response, message)
27 changes: 14 additions & 13 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -13,6 +13,7 @@
Optional,
Type,
Union,
cast,
)
from uuid import uuid4

@@ -28,6 +29,7 @@
)
from jupyter_ai_magics import Persona
from jupyter_ai_magics.providers import BaseProvider
from langchain.chains import LLMChain
from langchain.pydantic_v1 import BaseModel

if TYPE_CHECKING:
@@ -36,8 +38,8 @@
from langchain_core.chat_history import BaseChatMessageHistory


def get_preferred_dir(root_dir: str, preferred_dir: str) -> Optional[str]:
if preferred_dir != "":
def get_preferred_dir(root_dir: str, preferred_dir: Optional[str]) -> Optional[str]:
if preferred_dir is not None and preferred_dir != "":
preferred_dir = os.path.expanduser(preferred_dir)
if not preferred_dir.startswith(root_dir):
preferred_dir = os.path.join(root_dir, preferred_dir)
@@ -47,7 +49,7 @@ def get_preferred_dir(root_dir: str, preferred_dir: str) -> Optional[str]:

# Chat handler type, with specific attributes for each
class HandlerRoutingType(BaseModel):
routing_method: ClassVar[Union[Literal["slash_command"]]] = ...
routing_method: ClassVar[Union[Literal["slash_command"]]]
"""The routing method that sends commands to this handler."""


@@ -83,17 +85,17 @@ class BaseChatHandler:
multiple chat handler classes."""

# Class attributes
id: ClassVar[str] = ...
id: ClassVar[str]
"""ID for this chat handler; should be unique"""

name: ClassVar[str] = ...
name: ClassVar[str]
"""User-facing name of this handler"""

help: ClassVar[str] = ...
help: ClassVar[str]
"""What this chat handler does, which third-party models it contacts,
the data it returns to the user, and so on, for display in the UI."""

routing_type: ClassVar[HandlerRoutingType] = ...
routing_type: ClassVar[HandlerRoutingType]

uses_llm: ClassVar[bool] = True
"""Class attribute specifying whether this chat handler uses the LLM
@@ -153,9 +155,9 @@ def __init__(
self.help_message_template = help_message_template
self.chat_handlers = chat_handlers

self.llm = None
self.llm_params = None
self.llm_chain = None
self.llm: Optional[BaseProvider] = None
self.llm_params: Optional[dict] = None
self.llm_chain: Optional[LLMChain] = None

async def on_message(self, message: HumanChatMessage):
"""
@@ -168,9 +170,8 @@ async def on_message(self, message: HumanChatMessage):

# ensure the current slash command is supported
if self.routing_type.routing_method == "slash_command":
slash_command = (
"/" + self.routing_type.slash_id if self.routing_type.slash_id else ""
)
routing_type = cast(SlashCommandRoutingType, self.routing_type)
slash_command = "/" + routing_type.slash_id if routing_type.slash_id else ""
if slash_command in lm_provider_klass.unsupported_slash_commands:
self.reply(
"Sorry, the selected language model does not support this slash command."
7 changes: 4 additions & 3 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -41,10 +41,10 @@ def create_llm_chain(
prompt_template = llm.get_chat_prompt_template()
self.llm = llm

runnable = prompt_template | llm
runnable = prompt_template | llm # type:ignore
if not llm.manages_history:
runnable = RunnableWithMessageHistory(
runnable=runnable,
runnable=runnable, # type:ignore[arg-type]
get_session_history=self.get_llm_chat_memory,
input_messages_key="input",
history_messages_key="history",
@@ -106,6 +106,7 @@ async def process_message(self, message: HumanChatMessage):
# stream response in chunks. this works even if a provider does not
# implement streaming, as `astream()` defaults to yielding `_call()`
# when `_stream()` is not implemented on the LLM class.
assert self.llm_chain
async for chunk in self.llm_chain.astream(
{"input": message.body},
config={"configurable": {"last_human_msg": message}},
@@ -117,7 +118,7 @@
stream_id = self._start_stream(human_msg=message)
received_first_chunk = True

if isinstance(chunk, AIMessageChunk):
if isinstance(chunk, AIMessageChunk) and isinstance(chunk.content, str):
self._send_stream_chunk(stream_id, chunk.content)
elif isinstance(chunk, str):
self._send_stream_chunk(stream_id, chunk)
1 change: 1 addition & 0 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py
@@ -93,6 +93,7 @@ async def process_message(self, message: HumanChatMessage):

self.get_llm_chain()
with self.pending("Analyzing error", message):
assert self.llm_chain
response = await self.llm_chain.apredict(
extra_instructions=extra_instructions,
stop=["\nHuman:"],
3 changes: 2 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
@@ -226,7 +226,7 @@ class GenerateChatHandler(BaseChatHandler):
def __init__(self, log_dir: Optional[str], *args, **kwargs):
super().__init__(*args, **kwargs)
self.log_dir = Path(log_dir) if log_dir else None
self.llm = None
self.llm: Optional[BaseProvider] = None

def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
@@ -248,6 +248,7 @@ async def _generate_notebook(self, prompt: str):
# Save the user input prompt, the description property is now LLM generated.
outline["prompt"] = prompt

assert self.llm
if self.llm.allows_concurrency:
# fill the outline concurrently
await afill_outline(outline, llm=self.llm, verbose=True)
14 changes: 9 additions & 5 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py
@@ -223,7 +223,9 @@ async def learn_dir(
}
splitter = ExtensionSplitter(
splitters=splitters,
default_splitter=RecursiveCharacterTextSplitter(**splitter_kwargs),
default_splitter=RecursiveCharacterTextSplitter(
**splitter_kwargs # type:ignore[arg-type]
),
)

delayed = split(path, all_files, splitter=splitter)
@@ -352,7 +354,7 @@ async def aget_relevant_documents(
self, query: str
) -> Coroutine[Any, Any, List[Document]]:
if not self.index:
return []
return [] # type:ignore[return-value]

await self.delete_and_relearn()
docs = self.index.similarity_search(query)
@@ -370,12 +372,14 @@ def get_embedding_model(self):


class Retriever(BaseRetriever):
learn_chat_handler: LearnChatHandler = None
learn_chat_handler: LearnChatHandler = None # type:ignore[assignment]

def _get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents( # type:ignore[override]
self, query: str
) -> List[Document]:
raise NotImplementedError()

async def _aget_relevant_documents(
async def _aget_relevant_documents( # type:ignore[override]
self, query: str
) -> Coroutine[Any, Any, List[Document]]:
docs = await self.learn_chat_handler.aget_relevant_documents(query)
8 changes: 4 additions & 4 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
@@ -3,7 +3,7 @@
import os
import shutil
import time
from typing import List, Optional, Union
from typing import List, Optional, Type, Union

from deepmerge import always_merger as Merger
from jsonschema import Draft202012Validator as Validator
@@ -60,7 +60,7 @@ class BlockedModelError(Exception):
pass


def _validate_provider_authn(config: GlobalConfig, provider: AnyProvider):
def _validate_provider_authn(config: GlobalConfig, provider: Type[AnyProvider]):
# TODO: handle non-env auth strategies
if not provider.auth_strategy or provider.auth_strategy.type != "env":
return
@@ -147,7 +147,7 @@ def _init_config_schema(self):
os.makedirs(os.path.dirname(self.schema_path), exist_ok=True)
shutil.copy(OUR_SCHEMA_PATH, self.schema_path)

def _init_validator(self) -> Validator:
def _init_validator(self) -> None:
with open(OUR_SCHEMA_PATH, encoding="utf-8") as f:
schema = json.loads(f.read())
Validator.check_schema(schema)
@@ -364,7 +364,7 @@ def delete_api_key(self, key_name: str):
config_dict["api_keys"].pop(key_name, None)
self._write_config(GlobalConfig(**config_dict))

def update_config(self, config_update: UpdateConfigRequest):
def update_config(self, config_update: UpdateConfigRequest): # type:ignore
last_write = os.stat(self.config_path).st_mtime_ns
if config_update.last_read and config_update.last_read < last_write:
raise WriteConflictError(
@@ -30,7 +30,7 @@ def arxiv_to_text(id: str, output_dir: str) -> str:
output path to the downloaded TeX file
"""

import arxiv
import arxiv # type:ignore[import-not-found,import-untyped]

outfile = f"{id}-{datetime.now():%Y-%m-%d-%H-%M}.tex"
download_filename = "downloaded-paper.tar.gz"
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/extension.py
@@ -53,7 +53,7 @@

class AiExtension(ExtensionApp):
name = "jupyter_ai"
handlers = [
handlers = [ # type:ignore[assignment]
(r"api/ai/api_keys/(?P<api_key_name>\w+)", ApiKeysHandler),
(r"api/ai/config/?", GlobalConfigHandler),
(r"api/ai/chats/?", RootChatHandler),
