Skip to content

Commit

Permalink
Improve raw connector (#44)
Browse files Browse the repository at this point in the history
* Improve raw connector

* incorrect fix mypy

* improve README.rst

* Set port in test

* Default port
  • Loading branch information
aamalev authored Aug 1, 2023
1 parent c8f3911 commit bc04cbf
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 52 deletions.
4 changes: 4 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ _____________________
pool:
min_size: 1
max_size: 100
connection: # optional
init: mymodule.connection_init
setup: mymodule.connection_setup
class: mymodule.Connection
Storage
Expand Down
5 changes: 4 additions & 1 deletion aioworkers_pg/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
class AbstractPGConnector(AbstractConnector):
_dsn: Optional[URI] = None
_default_dsn: URI = URI("postgresql:///")
_default_port: int = 5432

def set_config(self, config: ValueExtractor) -> None:
cfg: ValueExtractor = config.new_parent(logger=__package__)
Expand All @@ -20,7 +21,9 @@ def set_config(self, config: ValueExtractor) -> None:
dsn = (dsn or self._default_dsn).with_host(host)

port = self.config.get_int("port", null=True)
if port:
if port in {self._default_port}:
port = 0
if port is not None:
dsn = (dsn or self._default_dsn).with_port(port)

username = self.config.get("username")
Expand Down
86 changes: 37 additions & 49 deletions aioworkers_pg/base.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,51 @@
import warnings
from typing import Awaitable, Callable, Optional, Type
from typing import Optional

import asyncpg
from aioworkers.core.base import AbstractConnector
from aioworkers.core.config import ValueExtractor

from aioworkers_pg.abc import AbstractPGConnector

class Connector(AbstractConnector):
_pool: Optional[asyncpg.pool.Pool]
_pool_init: Optional[Callable[[asyncpg.Connection], Awaitable]] = None
_pool_setup: Optional[Callable[[asyncpg.Connection], Awaitable]] = None
_connection_class: Optional[Type[asyncpg.connection.Connection]] = asyncpg.connection.Connection
_record_class: type = asyncpg.protocol.Record
_connect_kwargs: dict = {}

class Connector(AbstractPGConnector):
_connect_kwargs: dict
_pool: Optional[asyncpg.pool.Pool] = None

def __init__(self, *args, **kwargs):
self._pool = None
self._pool_init = self._default_pool_init
self._connect_kwargs = {
"init": self._default_connection_init,
}
super().__init__(*args, **kwargs)

def set_config(self, config: ValueExtractor) -> None:
cfg: ValueExtractor = config.new_parent(logger=__package__)
super().set_config(cfg)

# TODO: Remove deprecated code.
if self.config.get("connection.init"):
warnings.warn("Do not use connection.init config. Use pool.init", DeprecationWarning)
self._pool_init = self.context.get_object(self.config.get("connection.init"))
connection_init: Optional[str] = self.config.get("connection.init")
connection_setup: Optional[str] = self.config.get("connection.setup")
connection_class: Optional[str] = self.config.get("connection.class")

pool_config = dict(self.config.get("pool", {}))
if not pool_config:
# Do not process pool config if there is no any parameter
return
kwargs = dict(self.config.get("pool") or ())

pool_init: Optional[str] = pool_config.pop("init", None)
if pool_init:
self._pool_init = self.context.get_object(pool_init)
connection_init = kwargs.pop("init", connection_init)
if connection_init:
kwargs["init"] = self.context.get_object(connection_init)
else:
kwargs["init"] = self._default_connection_init

pool_setup: Optional[str] = pool_config.pop("setup", None)
if pool_setup:
self._pool_setup = self.context.get_object(pool_setup)
connection_setup = kwargs.pop("setup", connection_setup)
if connection_setup:
kwargs["setup"] = self.context.get_object(connection_setup)

connection_class: Optional[str] = pool_config.pop("connection_class", None)
connection_class = kwargs.pop("connection_class", connection_class)
if connection_class:
self._connection_class = self.context.get_object(connection_class)
kwargs["connection_class"] = self.context.get_object(connection_class)

record_class: Optional[str] = pool_config.pop("record_class", None)
record_class: Optional[str] = kwargs.pop("record_class", None)
if record_class:
self._record_class = self.context.get_object(record_class)
kwargs["record_class"] = self.context.get_object(record_class)

self._connect_kwargs = pool_config
self._connect_kwargs = kwargs

@property
def pool(self) -> asyncpg.pool.Pool:
Expand All @@ -61,28 +57,20 @@ def __getattr__(self, attr):
return getattr(self._pool, attr)

async def connect(self):
if self._pool is None:
self._pool = await self.pool_factory(self.config)

async def pool_factory(self, config) -> asyncpg.pool.Pool:
pool = await asyncpg.create_pool(
config.dsn,
init=self._pool_init,
setup=self._pool_setup,
connection_class=self._connection_class,
record_class=self._record_class,
**self._connect_kwargs,
)
self.logger.debug("Create pool with address %s", config.dsn)
return pool
assert self._pool is None, "Already created pool"
dsn = self._dsn
self._pool = await asyncpg.create_pool(dsn, **self._connect_kwargs)
if dsn:
dsn = dsn.with_password("***")
self.logger.debug("Create pool with address %s", dsn)

async def disconnect(self):
if self._pool is not None:
self.logger.debug("Close pool")
await self._pool.close()
self._pool = None
assert self._pool is not None, "Pool not created"
self.logger.debug("Close pool")
await self._pool.close()
self._pool = None

async def _default_pool_init(self, connection: asyncpg.Connection):
async def _default_connection_init(self, connection: asyncpg.Connection):
import json

# TODO: Need general solution to add codecs
Expand Down
21 changes: 19 additions & 2 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,35 @@
import pytest
from aioworkers.net.uri import URI

from aioworkers_pg.base import Connector


@pytest.fixture
def config(config, dsn):
uri = URI(dsn)
port = uri.port or 5432 # type: ignore # https://github.com/aioworkers/aioworkers/pull/202
database = (uri.path or "").strip("/") # type: ignore
config.update(
db={
"name": "db",
"cls": "aioworkers_pg.base.Connector",
"dsn": dsn,
"username": uri.username,
"password": uri.password,
"host": uri.hostname,
"database": database,
"port": port,
},
)
return config


async def test_connector(context):
def test_dsn(config, dsn):
c = Connector()
c.set_config(config.db)
assert str(c._dsn) == dsn


async def test_connector(context, dsn):
"""
Test common asyncpg pool methods which were bind to connector.
"""
Expand Down

0 comments on commit bc04cbf

Please sign in to comment.