Skip to content

Commit

Permalink
lazy load agate
Browse files Browse the repository at this point in the history
  • Loading branch information
dwreeves committed Mar 31, 2024
1 parent 0c44464 commit 8afdc79
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 20 deletions.
18 changes: 12 additions & 6 deletions dbt/adapters/clickhouse/connections.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import re
import time
from contextlib import contextmanager
from typing import Any, Optional, Tuple, Union
from typing import Any, Optional, Tuple, Union, TYPE_CHECKING

import agate
import dbt.exceptions
from dbt.adapters.sql import SQLConnectionManager
from dbt.contracts.connection import AdapterResponse, Connection

from dbt.adapters.clickhouse.dbclient import ChRetryableException, get_db_client
from dbt.adapters.clickhouse.logger import logger

if TYPE_CHECKING:
import agate

retryable_exceptions = [ChRetryableException]
ddl_re = re.compile(r'^\s*(CREATE|DROP|ALTER)\s', re.IGNORECASE)

Expand Down Expand Up @@ -60,21 +62,23 @@ def release(self):
pass # There is no "release" type functionality in the existing ClickHouse connectors

@classmethod
def get_table_from_response(cls, response, column_names) -> agate.Table:
def get_table_from_response(cls, response, column_names) -> "agate.Table":
"""
Build agate table from response.
:param response: ClickHouse query result
:param column_names: Table column names
"""
from dbt.clients.agate_helper import table_from_data_flat

data = []
for row in response:
data.append(dict(zip(column_names, row)))

return dbt.clients.agate_helper.table_from_data_flat(data, column_names)
return table_from_data_flat(data, column_names)

def execute(
self, sql: str, auto_begin: bool = False, fetch: bool = False, limit: Optional[int] = None
) -> Tuple[AdapterResponse, agate.Table]:
) -> Tuple[AdapterResponse, "agate.Table"]:
# Don't try to fetch the result of clustered DDL responses; we don't know what to do with them
if fetch and ddl_re.match(sql):
fetch = False
Expand All @@ -97,7 +101,9 @@ def execute(
query_result.result_set, query_result.column_names
)
else:
table = dbt.clients.agate_helper.empty_table()
from dbt.clients.agate_helper import empty_table

table = empty_table()
return AdapterResponse(_message=status), table

def add_query(
Expand Down
34 changes: 20 additions & 14 deletions dbt/adapters/clickhouse/impl.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import csv
import io
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, TYPE_CHECKING

import agate
from dbt.adapters.base import AdapterConfig, available
from dbt.adapters.base.impl import BaseAdapter, ConstraintSupport
from dbt.adapters.base.relation import BaseRelation, InformationSchema
Expand Down Expand Up @@ -31,6 +30,9 @@
from dbt.adapters.clickhouse.relation import ClickHouseRelation, ClickHouseRelationType
from dbt.adapters.clickhouse.util import NewColumnDataType, compare_versions

if TYPE_CHECKING:
import agate

GET_CATALOG_MACRO_NAME = 'get_catalog'
LIST_SCHEMAS_MACRO_NAME = 'list_schemas'

Expand Down Expand Up @@ -73,29 +75,31 @@ def date_function(cls):
return 'now()'

@classmethod
def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_text_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
return 'String'

@classmethod
def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_number_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
import agate

decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))
# We match these types to the Column.TYPE_LABELS for consistency
return 'Float32' if decimals else 'Int32'

@classmethod
def convert_boolean_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_boolean_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
return 'Bool'

@classmethod
def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_datetime_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
return 'DateTime'

@classmethod
def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_date_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
return 'Date'

@classmethod
def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_time_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
raise NotImplementedError('`convert_time_type` is not implemented for this adapter!')

@available.parse(lambda *a, **k: {})
Expand Down Expand Up @@ -308,13 +312,15 @@ def get_ch_database(self, schema: str):
except DbtRuntimeError:
return None

def get_catalog(self, manifest) -> Tuple[agate.Table, List[Exception]]:
def get_catalog(self, manifest) -> Tuple["agate.Table", List[Exception]]:
from dbt.clients.agate_helper import empty_table

relations = self._get_catalog_relations(manifest)
schemas = set(relation.schema for relation in relations)
if schemas:
catalog = self._get_one_catalog(InformationSchema(Path()), schemas, manifest)
else:
catalog = dbt.clients.agate_helper.empty_table()
catalog = empty_table()
return catalog, []

def get_filtered_catalog(
Expand All @@ -324,7 +330,7 @@ def get_filtered_catalog(
if relations and catalog:
relation_map = {(r.schema, r.identifier) for r in relations}

def in_map(row: agate.Row):
def in_map(row: "agate.Row"):
s = _expect_row_value("table_schema", row)
i = _expect_row_value("table_name", row)
return (s, i) in relation_map
Expand Down Expand Up @@ -488,17 +494,17 @@ class ClickHouseDatabase:
comment: str


def _expect_row_value(key: str, row: agate.Row):
def _expect_row_value(key: str, row: "agate.Row"):
if key not in row.keys():
raise DbtInternalError(f'Got a row without \'{key}\' column, columns: {row.keys()}')

return row[key]


def _catalog_filter_schemas(manifest: Manifest) -> Callable[[agate.Row], bool]:
def _catalog_filter_schemas(manifest: Manifest) -> Callable[["agate.Row"], bool]:
schemas = frozenset((None, s) for d, s in manifest.get_used_schemas())

def test(row: agate.Row) -> bool:
def test(row: "agate.Row") -> bool:
table_database = _expect_row_value('table_database', row)
table_schema = _expect_row_value('table_schema', row)
if table_schema is None:
Expand Down

0 comments on commit 8afdc79

Please sign in to comment.