Skip to content

Commit

Permalink
Update unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hunyadi committed Sep 10, 2024
1 parent 68ad831 commit c36dc1e
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 5 deletions.
1 change: 0 additions & 1 deletion pysqlsync/dialect/mssql/object_types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import typing
from typing import Optional

from pysqlsync.formation.object_types import (
Expand Down
20 changes: 18 additions & 2 deletions pysqlsync/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
import logging
import re
import typing
from urllib.parse import unquote, urlparse
from typing import Optional
from urllib.parse import parse_qs, unquote, urlparse

from strong_typing.inspection import get_module_classes

from .base import BaseConnection, BaseEngine, BaseGenerator, Explorer
from .connection import ConnectionParameters
from .connection import ConnectionParameters, ConnectionSSLMode

LOGGER = logging.getLogger("pysqlsync")

Expand Down Expand Up @@ -80,12 +81,27 @@ def get_parameters(url: str) -> tuple[str, ConnectionParameters]:
"""

parts = urlparse(url, allow_fragments=False)
query = parse_qs(parts.query, strict_parsing=True)
ssl: Optional[ConnectionSSLMode] = None
if "ssl" in query:
if len(query["ssl"]) != 1:
raise ValueError(
"only a single `ssl` parameter is permitted in a connection string"
)
ssl_mode = query["ssl"][0]
for v in ConnectionSSLMode.__members__.values():
if ssl_mode == v.value:
ssl = v
break
else:
raise ValueError(f"unsupported SSL mode: {ssl_mode}")
return parts.scheme, ConnectionParameters(
host=parts.hostname,
port=parts.port,
username=unquote(parts.username) if parts.username else None,
password=unquote(parts.password) if parts.password else None,
database=parts.path.lstrip("/") if parts.path else None,
ssl=ssl,
)


Expand Down
27 changes: 25 additions & 2 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
from urllib.parse import quote

from pysqlsync.connection import ConnectionParameters
from pysqlsync.connection import ConnectionParameters, ConnectionSSLMode
from pysqlsync.factory import get_dialect, get_parameters


Expand All @@ -14,7 +14,7 @@ def test_dialect_unavailable(self) -> None:
with self.assertRaises(RuntimeError):
engine.create_generator()

def test_connection_parameters(self) -> None:
def test_connection_string(self) -> None:
host = "server.example.com"
port = 2310
username = "my+user@example.com"
Expand All @@ -28,10 +28,33 @@ def test_connection_parameters(self) -> None:
self.assertEqual(params.username, username)
self.assertEqual(params.password, password)
self.assertEqual(params.database, database)
self.assertEqual(params.ssl, None)
self.assertEqual(
str(params), r"my%2Buser%40example.com@server.example.com:2310/database"
)

def test_connection_query_parameters(self) -> None:
host = "server.example.com"
port = 2310
database = "database"
url_prefix = f"postgresql://{host}:{port}/{database}"

url = f"{url_prefix}?key=value"
dialect, params = get_parameters(url)
self.assertEqual(dialect, "postgresql")
self.assertEqual(params.host, host)
self.assertEqual(params.port, port)
self.assertEqual(params.database, database)
self.assertEqual(params.ssl, None)

url = f"{url_prefix}?ssl=verify-full"
dialect, params = get_parameters(url)
self.assertEqual(dialect, "postgresql")
self.assertEqual(params.host, host)
self.assertEqual(params.port, port)
self.assertEqual(params.database, database)
self.assertEqual(params.ssl, ConnectionSSLMode.verify_full)


if __name__ == "__main__":
unittest.main()

0 comments on commit c36dc1e

Please sign in to comment.