Skip to content

Commit

Permalink
feat: add jupyter api runtime (#244)
Browse files Browse the repository at this point in the history
  • Loading branch information
opqrstuvcut authored Oct 7, 2024
1 parent f2e9ba9 commit 047bf41
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 12 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ https://github.com/benlubas/molten-nvim/assets/56943754/17ae81c0-306f-4496-bce8-
- `pyperclip` if you want to use `molten_copy_output`
- `nbformat` for importing and exporting output to jupyter notebooks files
- `pillow` for opening images with `:MoltenImagePopup`
- `requests` and `websocket-client` for connecting to the Jupyter Server API via HTTP and WebSocket with `:MoltenInit <Jupyter server URL>`

You can run `:checkhealth` to see what you have installed.

Expand Down
136 changes: 136 additions & 0 deletions rplugin/python3/molten/jupyter_server_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import json
import time
import uuid
from queue import Empty as EmptyQueueException
from queue import Queue
from threading import Thread
from typing import Any, Dict
from urllib.parse import parse_qs, urlparse

from molten.runtime_state import RuntimeState


class JupyterAPIClient:
def __init__(self,
url: str,
kernel_info: Dict[str, Any],
headers: Dict[str, str]):
self._base_url = url
self._kernel_info = kernel_info
self._headers = headers

self._recv_queue: Queue[Dict[str, Any]] = Queue()

import requests
self.requests = requests

def get_stdin_msg(self, **kwargs):
return None

def wait_for_ready(self, timeout: float = 0.):
start = time.time()
while True:
response = self.requests.get(self._kernel_api_base,
headers=self._headers)
response = json.loads(response.text)

if response["execution_state"] != "idle" and time.time() - start > timeout:
raise RuntimeError

# Discard unnecessary messages.
while True:
try:
response = self.get_iopub_msg()
except EmptyQueueException:
return


def start_channels(self) -> None:
import websocket

parsed_url = urlparse(self._base_url)
self._socket = websocket.create_connection(f"ws://{parsed_url.hostname}:{parsed_url.port}"
f"/api/kernels/{self._kernel_info['id']}/channels",
header=self._headers,
)
self._kernel_api_base = f"{self._base_url}/api/kernels/{self._kernel_info['id']}"

self._iopub_recv_thread = Thread(target=self._recv_message)
self._iopub_recv_thread.start()

def _recv_message(self) -> None:
while True:
response = json.loads(self._socket.recv())
self._recv_queue.put(response)

def get_iopub_msg(self, **kwargs):
if self._recv_queue.empty():
raise EmptyQueueException

response = self._recv_queue.get()

return response

def execute(self, code: str):
header = {
'msg_type': 'execute_request',
'msg_id': uuid.uuid1().hex,
'session': uuid.uuid1().hex
}

message = json.dumps({
'header': header,
'parent_header': header,
'metadata': {},
'content': {
'code': code,
'silent': False
}
})
self._socket.send(message)

def shutdown(self):
self.requests.delete(self._kernel_api_base,
headers=self._headers)

def cleanup_connection_file(self):
pass

class JupyterAPIManager:
def __init__(self,
url: str,
):
parsed_url = urlparse(url)
self._base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

token = parse_qs(parsed_url.query).get("token")
if token:
self._headers = {'Authorization': f'token {token[0]}'}
else:
# Run notebook with --NotebookApp.disable_check_xsrf="True".
self._headers = {}

import requests
self.requests = requests

def start_kernel(self) -> None:
url = f"{self._base_url}/api/kernels"
response = self.requests.post(url,
headers=self._headers)
self._kernel_info = json.loads(response.text)
assert "id" in self._kernel_info, "Could not connect to Jupyter Server API. The URL specified may be incorrect."
self._kernel_api_base = f"{url}/{self._kernel_info['id']}"

def client(self) -> JupyterAPIClient:
return JupyterAPIClient(url=self._base_url,
kernel_info=self._kernel_info,
headers=self._headers)

def interrupt_kernel(self) -> None:
self.requests.post(f"{self._kernel_api_base}/interrupt",
headers=self._headers)

def restart_kernel(self) -> None:
self.state = RuntimeState.STARTING
self.requests.post(f"{self._kernel_api_base}/restart",
headers=self._headers)
33 changes: 21 additions & 12 deletions rplugin/python3/molten/runtime.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from datetime import datetime
from typing import Optional, Tuple, List, Dict, Generator, IO, Any
from enum import Enum
from contextlib import contextmanager
from queue import Empty as EmptyQueueException
import os
Expand All @@ -20,21 +19,17 @@
to_outputchunk,
clean_up_text,
)


class RuntimeState(Enum):
STARTING = 0
IDLE = 1
RUNNING = 2
from molten.runtime_state import RuntimeState
from molten.jupyter_server_api import JupyterAPIClient, JupyterAPIManager


class JupyterRuntime:
state: RuntimeState
kernel_name: str
kernel_id: str

kernel_manager: jupyter_client.KernelManager # type: ignore
kernel_client: jupyter_client.KernelClient # type: ignore
kernel_manager: jupyter_client.KernelManager | JupyterAPIManager # type: ignore
kernel_client: jupyter_client.KernelClient | JupyterAPIClient # type: ignore

allocated_files: List[str]

Expand All @@ -48,7 +43,14 @@ def __init__(self, nvim: Nvim, kernel_name: str, kernel_id: str, options: Molten
self.nvim = nvim
self.nvim.exec_lua("_prompt_stdin = require('prompt').prompt_stdin")

if ".json" not in self.kernel_name:
if kernel_name.startswith("http://") or kernel_name.startswith("https://"):
self.external_kernel = False
self.kernel_manager = JupyterAPIManager(kernel_name)
self.kernel_manager.start_kernel()
self.kernel_client = self.kernel_manager.client()
self.kernel_client.start_channels()
self.options = options
elif ".json" not in self.kernel_name:
self.external_kernel = False
self.kernel_manager = jupyter_client.manager.KernelManager(kernel_name=kernel_name)
self.kernel_manager.start_kernel()
Expand Down Expand Up @@ -204,7 +206,10 @@ def tick(self, output: Optional[Output]) -> bool:

assert isinstance(
self.kernel_client,
jupyter_client.blocking.client.BlockingKernelClient,
(
jupyter_client.blocking.client.BlockingKernelClient,
JupyterAPIClient,
),
)

if not self.is_ready():
Expand Down Expand Up @@ -240,7 +245,11 @@ def tick_input(self):
if not self.is_ready:
return

assert isinstance(self.kernel_client, jupyter_client.blocking.client.BlockingKernelClient)
assert isinstance(
self.kernel_client,
(jupyter_client.blocking.client.BlockingKernelClient,
JupyterAPIClient),
)

try:
msg = self.kernel_client.get_stdin_msg(timeout=0)
Expand Down
7 changes: 7 additions & 0 deletions rplugin/python3/molten/runtime_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from enum import Enum


class RuntimeState(Enum):
STARTING = 0
IDLE = 1
RUNNING = 2

0 comments on commit 047bf41

Please sign in to comment.