Skip to content

Commit

Permalink
Revert speculative cast (#695)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuang11 authored Aug 23, 2024
1 parent 42700fc commit eaa1cff
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 49 deletions.
16 changes: 0 additions & 16 deletions lumen/sources/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,22 +716,6 @@ class BaseSQLSource(Source):
# Declare this source supports SQL transforms
_supports_sql = True

def _compute_subset_schema(self, data):
    """
    Speculatively casts string-like columns of a sampled DataFrame to
    numeric types and computes the schema of the result.

    Columns exposing a ``.str`` accessor whose values are all digits
    are cast to int; otherwise (or if the int cast fails) a float
    cast is attempted. Columns that cannot be converted, or that hold
    no non-null values, are left untouched.

    Parameters
    ----------
    data
        Sampled table data. Successfully cast columns are replaced
        in place on this object.

    Returns
    -------
    tuple
        The ``['items']['properties']`` entry of the schema returned
        by ``get_dataframe_schema`` for the (possibly cast) data, and
        a dict mapping each cast column name to ``'INT'`` or
        ``'FLOAT'``.
    """
    casts = {}
    for col in data.columns:
        series = data[col]
        if not hasattr(series, 'str'):
            continue
        # Without any non-null sample value there is no evidence for a
        # numeric type; the all-digits check would pass vacuously.
        if not series.notna().any():
            continue
        if series.str.isdigit().all():
            try:
                data[col] = series.astype(int)
            except (ValueError, TypeError, OverflowError):
                # Missing values, unicode digits int() rejects, or
                # values too large for the integer dtype: fall through
                # and try the float cast instead.
                pass
            else:
                casts[col] = 'INT'
                continue
        try:
            data[col] = series.astype(float)
        except (ValueError, TypeError):
            continue
        casts[col] = 'FLOAT'
    return get_dataframe_schema(data)['items']['properties'], casts

def get_sql_expr(self, table: str):
"""
Returns the SQL expression corresponding to a particular table.
Expand Down
16 changes: 7 additions & 9 deletions lumen/sources/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
from ..serializers import Serializer
from ..transforms import Filter
from ..transforms.sql import (
SQLCast, SQLDistinct, SQLFilter, SQLLimit, SQLMinMax,
SQLDistinct, SQLFilter, SQLLimit, SQLMinMax,
)
from ..util import get_dataframe_schema
from .base import (
BaseSQLSource, Source, cached, cached_schema,
)
Expand Down Expand Up @@ -271,24 +272,23 @@ def get_schema(
tables = [table]

schemas = {}
sql_limit = SQLLimit(limit=limit or 3)
sql_limit = SQLLimit(limit=limit or 1)
for entry in tables:
if not self.load_schema:
schemas[entry] = {}
continue
sql_expr = self.get_sql_expr(entry)
data = self._connection.execute(sql_limit.apply(sql_expr)).fetch_df()
schema, casts = self._compute_subset_schema(data)
schemas[entry] = schema
schemas[entry] = schema = get_dataframe_schema(data)['items']['properties']
if limit:
continue

enums, min_maxes = [], []
for name, col_schema in schema.items():
if 'inclusiveMinimum' in col_schema or name in casts:
min_maxes.append(name)
elif 'enum' in col_schema:
if 'enum' in col_schema:
enums.append(name)
elif 'inclusiveMinimum' in col_schema:
min_maxes.append(name)
for col in enums:
distinct_expr = SQLDistinct(columns=[col]).apply(sql_expr)
distinct_expr = ' '.join(distinct_expr.splitlines())
Expand All @@ -298,8 +298,6 @@ def get_schema(
if not min_maxes:
continue

if casts:
sql_expr = SQLCast(columns=casts, force=True).apply(sql_expr)
minmax_expr = SQLMinMax(columns=min_maxes).apply(sql_expr)
minmax_expr = ' '.join(minmax_expr.splitlines())
minmax_data = self._connection.execute(minmax_expr).fetch_df()
Expand Down
5 changes: 2 additions & 3 deletions lumen/sources/intake_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def get_schema(
tables = [table]

schemas = {}
sql_limit = SQLLimit(limit=limit or 3)
sql_limit = SQLLimit(limit=limit or 1)
for entry in tables:
if not self.load_schema:
schemas[entry] = {}
Expand Down Expand Up @@ -116,9 +116,8 @@ def get_schema(
continue

# Calculate numeric schemas
transforms = [SQLMinMax(columns=min_maxes)]
minmax_data = self._read(
self._apply_transforms(source, transforms)
self._apply_transforms(source, [SQLMinMax(columns=min_maxes)])
)
for col in min_maxes:
kind = data[col].dtype.kind
Expand Down
21 changes: 0 additions & 21 deletions lumen/transforms/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,27 +168,6 @@ def apply(self, sql_in):
return self._render_template(template, sql_in=sql_in, columns=', '.join(map(quote, self.columns)))


class SQLCast(SQLTransform):
    """
    Wraps a SQL query in a SELECT that adds cast versions of the
    configured columns alongside all existing columns.

    ``columns`` maps each column name to the SQL type it is cast to.
    When ``force`` is enabled TRY_CAST is emitted instead of CAST.
    """

    columns = param.Dict(default={})

    force = param.Boolean(default=False, doc="Whether to cast to null if it can't be converted.")

    transform_type: ClassVar[str] = 'sql_cast'

    def apply(self, sql_in):
        # Nothing to cast: hand the query back unchanged.
        if not self.columns:
            return sql_in
        sql_in = super().apply(sql_in)
        if self.force:
            keyword = 'TRY_CAST'
        else:
            keyword = 'CAST'
        expressions = []
        for name, target in self.columns.items():
            expressions.append(f'{keyword}({name} as {target}) as {name}')
        return self._render_template(
            "SELECT *, {{casts}} FROM ({{sql_in}})",
            sql_in=sql_in,
            casts=",".join(expressions),
        )


class SQLFilter(SQLTransform):
"""
Translates Lumen Filter query into a SQL WHERE statement.
Expand Down

0 comments on commit eaa1cff

Please sign in to comment.