Skip to content

Commit

Permalink
Add SQLite client and converter
Browse files Browse the repository at this point in the history
Recap can now read SQLite schemas as Recap types. SQLite's schema system is
somewhat strange. Some notes:

1. Any column can store any type.
2. SQLite has 5 storage classes (null, int, real, text, blob).
3. STRICT forces column types to be the storage types.
4. non-STRICT tables allow any strings for column types.
5. non-STRICT column types are hints to coerce types as they're written to disk.
6. Parentheses in types (e.g. DOUBLE(6, 2)) are ignored by SQLite.

See https://www.sqlite.org/datatype3.html#storage_classes_and_datatypes for more
details.

With all of this in mind, Recap's SQLiteConverter works according to SQLite's
affinity rules. This means:

1. Unknown types are treated as "ANY", which is a union of all storage types.
2. SQLiteConverter pays attention to precision/scale for REAL, etc.
3. SQLiteConverter pays attention to lengths for VARCHAR(255), etc.
4. SQLiteConverter treats date, datetime, time, and timestamp as ANY types.

Closes #418
  • Loading branch information
criccomini committed Feb 5, 2024
1 parent 0ea028b commit 6ae133f
Show file tree
Hide file tree
Showing 10 changed files with 735 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ jobs:
- name: Test with pytest
env:
RECAP_URLS: '["postgresql://postgres:password@localhost:5432/testdb"]'
RECAP_URLS: '["postgresql://postgres:password@localhost:5432/testdb", "sqlite:///file:mem1?mode=memory&cache=shared&uri=true"]'
run: |
pdm run integration
Expand Down
1 change: 1 addition & 0 deletions recap/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"mysql": "recap.clients.mysql.MysqlClient",
"postgresql": "recap.clients.postgresql.PostgresqlClient",
"snowflake": "recap.clients.snowflake.SnowflakeClient",
"sqlite": "recap.clients.sqlite.SQLiteClient",
"thrift+hms": "recap.clients.hive_metastore.HiveMetastoreClient",
}

Expand Down
146 changes: 146 additions & 0 deletions recap/clients/sqlite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
from __future__ import annotations

from contextlib import contextmanager
from re import compile as re_compile
from typing import Any, Generator

from recap.clients.dbapi import Connection
from recap.converters.sqlite import SQLiteAffinity, SQLiteConverter
from recap.types import StructType

# Keyword arguments accepted by Python's sqlite3.connect(). URL arguments
# whose keys are not in this set are dropped before connecting (see
# SQLiteClient.create).
SQLITE3_CONNECT_ARGS = {
    "database",
    "timeout",
    "detect_types",
    "isolation_level",
    "check_same_thread",
    "factory",
    "cached_statements",
    "uri",
}


class SQLiteClient:
def __init__(self, connection: Connection) -> None:
self.connection = connection
self.converter = SQLiteConverter()

@staticmethod
@contextmanager
def create(url: str, **url_args) -> Generator[SQLiteClient, None, None]:
import sqlite3

# Strip sqlite:/// URL prefix
url_args["database"] = url[len("sqlite:///") :]

# Only include kwargs that are valid for PsycoPG2 parse_dsn()
url_args = {k: v for k, v in url_args.items() if k in SQLITE3_CONNECT_ARGS}

with sqlite3.connect(**url_args) as client:
yield SQLiteClient(client) # type: ignore

@staticmethod
def parse(method: str, **url_args) -> tuple[str, list[Any]]:
from urllib.parse import urlunparse

match method:
case "ls":
return (url_args["url"], [])
case "schema":
table = url_args["paths"].pop(-1)
connection_url = urlunparse(
[
url_args.get("dialect") or url_args.get("scheme"),
url_args.get("netloc"),
# Include / prefix for paths
"/".join(url_args.get("paths", [])),
url_args.get("params"),
url_args.get("query"),
url_args.get("fragment"),
]
)

# urlunsplit does not double slashes if netloc is empty. But most
# URLs with empty netloc should have a double slash (e.g.
# bigquery:// or sqlite:///some/file.db). Include an extra "/"
# because the root path is not included with an empty netloc
# and join().
if not url_args.get("netloc"):
connection_url = connection_url.replace(":", ":///", 1)

return (connection_url, [table])
case _:
raise ValueError("Invalid method")

def ls(self) -> list[str]:
cursor = self.connection.cursor()
cursor.execute("SELECT name FROM sqlite_schema WHERE type='table'")
return [row[0] for row in cursor.fetchall()]

def schema(self, table: str) -> StructType:
cursor = self.connection.cursor()

# Validate that table exists since we want to prevent SQL injections in
# the PRAGMA call
if not self._table_exists(table):
raise ValueError(f"Table '{table}' does not exist in the database.")

cursor.execute(f"PRAGMA table_info({table});")
names = [name[0].upper() for name in cursor.description]
rows = []

for row_cells in cursor.fetchall():
row = dict(zip(names, row_cells))
row = self.add_information_schema(row)
row = self.add_information_schema(row)
rows.append(row)

return self.converter.to_recap(rows)

def add_information_schema(self, row: dict[str, Any]) -> dict[str, Any]:
"""
SQLite does not have an INFORMATION_SCHEMA, so we need to add these
columns.
:param row: A row from the PRAGMA table_info() query.
:return: The row with the INFORMATION_SCHEMA columns added.
"""

is_not_null = row["NOTNULL"] or row["PK"]

# Set defaults.
information_schema_cols = {
"COLUMN_NAME": row["NAME"],
"IS_NULLABLE": "NO" if is_not_null else "YES",
"COLUMN_DEFAULT": row["DFLT_VALUE"],
"NUMERIC_PRECISION": None,
"NUMERIC_SCALE": None,
"CHARACTER_OCTET_LENGTH": None,
}

# Extract precision, scale, and octet length.
numeric_pattern = re_compile(r"(\w+)\((\d+)(?:,\s*(\d+))?\)")
param_match = numeric_pattern.search(row["TYPE"])

if param_match:
# Extract matched values
base_type, precision, scale = param_match.groups()
base_type = base_type.upper()
precision = int(precision)
scale = int(scale) if scale else 0

match SQLiteConverter.get_affinity(base_type):
case SQLiteAffinity.INTEGER | SQLiteAffinity.REAL | SQLiteAffinity.NUMERIC:
information_schema_cols["NUMERIC_PRECISION"] = precision
information_schema_cols["NUMERIC_SCALE"] = scale
case SQLiteAffinity.TEXT | SQLiteAffinity.BLOB:
information_schema_cols["CHARACTER_OCTET_LENGTH"] = precision

return row | information_schema_cols

def _table_exists(self, table: str) -> bool:
cursor = self.connection.cursor()
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,)
)
return bool(cursor.fetchone())
90 changes: 90 additions & 0 deletions recap/converters/sqlite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from enum import Enum
from typing import Any

from recap.converters.dbapi import DbapiConverter
from recap.types import (
BytesType,
FloatType,
IntType,
NullType,
RecapType,
StringType,
UnionType,
)

# SQLite's maximum length is 2^31-1 bytes, or 2147483647 bytes. Used as the
# fallback length when a column declares no explicit length.
SQLITE_MAX_LENGTH = 2147483647


class SQLiteAffinity(Enum):
    """
    SQLite uses column affinity to map non-STRICT table columns to values. See
    https://www.sqlite.org/datatype3.html#type_affinity for details.
    """

    INTEGER = "integer"  # Mapped to a 64-bit int by SQLiteConverter.
    REAL = "real"  # Mapped to a 32- or 64-bit float, depending on precision.
    TEXT = "text"  # Mapped to a string.
    BLOB = "blob"  # Mapped to bytes; also used when no type is declared.
    NUMERIC = "numeric"  # May hold any storage class; mapped to a union.


class SQLiteConverter(DbapiConverter):
    """Convert SQLite column metadata into Recap types via affinity rules."""

    def _parse_type(self, column_props: dict[str, Any]) -> RecapType:
        """Map a single column's properties to a Recap type.

        :param column_props: A column row with INFORMATION_SCHEMA-style keys
            (COLUMN_NAME, TYPE, CHARACTER_OCTET_LENGTH, NUMERIC_PRECISION).
        :return: The Recap type for the column.
        :raises ValueError: If the column's affinity is unrecognized.
        """

        name = column_props["COLUMN_NAME"]
        declared_type = column_props["TYPE"]
        octet_length = column_props["CHARACTER_OCTET_LENGTH"]
        precision = column_props["NUMERIC_PRECISION"]
        affinity = SQLiteConverter.get_affinity(declared_type)

        if affinity is SQLiteAffinity.INTEGER:
            return IntType(bits=64)

        if affinity is SQLiteAffinity.REAL:
            # Treat small declared precisions as single-precision floats.
            if precision and precision <= 23:
                return FloatType(bits=32)
            return FloatType(bits=64)

        if affinity is SQLiteAffinity.TEXT:
            return StringType(bytes_=octet_length or SQLITE_MAX_LENGTH)

        if affinity is SQLiteAffinity.BLOB:
            return BytesType(bytes_=octet_length or SQLITE_MAX_LENGTH)

        if affinity is SQLiteAffinity.NUMERIC:
            # NUMERIC affinity may contain values using all five storage classes
            return UnionType(
                types=[
                    NullType(),
                    IntType(bits=64),
                    FloatType(bits=64),
                    StringType(bytes_=SQLITE_MAX_LENGTH),
                    BytesType(bytes_=SQLITE_MAX_LENGTH),
                ]
            )

        raise ValueError(f"Unsupported `{declared_type}` type for `{name}`")

    @staticmethod
    def get_affinity(column_type: str | None) -> SQLiteAffinity:
        """
        Encode affinity rules as defined here:
        https://www.sqlite.org/datatype3.html#determination_of_column_affinity

        :param column_type: The column type to determine the affinity of.
        :return: The affinity of the column type.
        """

        normalized = (column_type or "").upper()

        # A missing/empty declared type gets BLOB affinity.
        if not normalized:
            return SQLiteAffinity.BLOB

        # Substring checks, in the order mandated by SQLite's affinity rules.
        rules = (
            (("INT",), SQLiteAffinity.INTEGER),
            (("CHAR", "TEXT", "CLOB"), SQLiteAffinity.TEXT),
            (("BLOB",), SQLiteAffinity.BLOB),
            (("REAL", "FLOA", "DOUB"), SQLiteAffinity.REAL),
        )
        for needles, affinity in rules:
            if any(needle in normalized for needle in needles):
                return affinity

        # Everything else (e.g. DECIMAL, DATE, BOOLEAN) is NUMERIC.
        return SQLiteAffinity.NUMERIC
2 changes: 1 addition & 1 deletion tests/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,6 @@ services:
retries: 5

hive-metastore:
image: ghcr.io/criccomini/hive-metastore-standalone:latest
image: ghcr.io/recap-build/hive-metastore-standalone:latest
ports:
- "9083:9083"
5 changes: 4 additions & 1 deletion tests/integration/server/test_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ def teardown_class(cls):
def test_ls_root(self):
response = client.get("/ls")
assert response.status_code == 200
assert response.json() == ["postgresql://localhost:5432/testdb"]
assert response.json() == [
"postgresql://localhost:5432/testdb",
"sqlite:///file:mem1?mode=memory&cache=shared&uri=true",
]

def test_ls_subpath(self):
response = client.get("/ls/postgresql://localhost:5432/testdb")
Expand Down
83 changes: 59 additions & 24 deletions tests/integration/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sqlite3
from json import loads

import psycopg2
Expand All @@ -13,33 +14,56 @@ class TestCli:
@classmethod
def setup_class(cls):
# Connect to the PostgreSQL database
cls.connection = psycopg2.connect(
cls.postgresql_connection = psycopg2.connect(
host="localhost",
port="5432",
user="postgres",
password="password",
dbname="testdb",
)

# Create tables
cursor = cls.connection.cursor()
# Create PostgreSQL tables
cursor = cls.postgresql_connection.cursor()
cursor.execute("CREATE TABLE IF NOT EXISTS test_types (test_integer INTEGER);")
cls.connection.commit()
cls.postgresql_connection.commit()

# Create a temporary SQLite database
cls.sqlite_connection = sqlite3.connect(
"file:mem1?mode=memory&cache=shared",
uri=True,
)
cursor = cls.sqlite_connection.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS test_sqlite_types (
test_integer INTEGER
);
"""
)
cls.sqlite_connection.commit()

@classmethod
def teardown_class(cls):
# Delete tables
cursor = cls.connection.cursor()
cursor = cls.postgresql_connection.cursor()
cursor.execute("DROP TABLE IF EXISTS test_types;")
cls.connection.commit()
cls.postgresql_connection.commit()

# Close the connection
cls.connection.close()
# Close the connections
cls.postgresql_connection.close()
cls.sqlite_connection.close()

@pytest.mark.parametrize(
"cmd, url, expected",
[
["ls", "", ["postgresql://localhost:5432/testdb"]],
[
"ls",
"",
[
"postgresql://localhost:5432/testdb",
"sqlite:///file:mem1?mode=memory&cache=shared&uri=true",
],
],
[
"ls",
"postgresql://postgres:password@localhost:5432",
Expand All @@ -66,26 +90,37 @@ def teardown_class(cls):
],
],
["ls", "postgresql://localhost:5432/testdb/public", ["test_types"]],
],
)
def test_ls(self, cmd, url, expected):
result = runner.invoke(app, [cmd, url])
assert result.exit_code == 0
assert loads(result.stdout) == expected

def test_schema(self):
result = runner.invoke(
app,
[
"ls",
"sqlite:///file:mem1?mode=memory&cache=shared&uri=true",
["test_sqlite_types"],
],
[
"schema",
"postgresql://localhost:5432/testdb/public/test_types",
{
"type": "struct",
"fields": [
{"type": "int32", "name": "test_integer", "optional": True}
],
},
],
)
[
"schema",
"sqlite:///file:mem1/test_sqlite_types?mode=memory&cache=shared&uri=true",
{
"type": "struct",
"fields": [
{"type": "int64", "name": "test_integer", "optional": True}
],
},
],
],
)
def test_cmds(self, cmd, url, expected):
result = runner.invoke(app, [cmd, url])
assert result.exit_code == 0
assert loads(result.stdout) == {
"type": "struct",
"fields": [{"type": "int32", "name": "test_integer", "optional": True}],
}
assert loads(result.stdout) == expected

def test_schema_avro(self):
result = runner.invoke(
Expand Down
Loading

0 comments on commit 6ae133f

Please sign in to comment.