Skip to content

Commit

Permalink
Fix dataclass field type in _row_to_sql (#266)
Browse files Browse the repository at this point in the history
In complement to #257 , dataclass field type needs to be adjusted in
other places.
This fix is required by databrickslabs/ucx#2526

Co-authored-by: Eric Vergnaud <eric.vergnaud@databricks.com>
  • Loading branch information
ericvergnaud and ericvergnaud authored Sep 4, 2024
1 parent 4c8494b commit 21995fd
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions src/databricks/labs/lsql/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,7 @@ def create_table(self, full_name: str, klass: Dataclass):
def _schema_for(cls, klass: Dataclass):
fields = []
for f in dataclasses.fields(klass):
field_type = f.type
# workaround rare (Python?) issue where f.type is the type name instead of the type itself
# this seems to happen when the dataclass is first used from a file importing it
if isinstance(field_type, str):
try:
field_type = __builtins__[field_type]
except TypeError as e:
logger.warning(f"Could not load type {field_type}", exc_info=e)
field_type = cls._field_type(f)
if isinstance(field_type, UnionType):
field_type = field_type.__args__[0]
if field_type not in cls._builtin_type_mapping:
Expand All @@ -89,6 +82,17 @@ def _schema_for(cls, klass: Dataclass):
fields.append(f"{f.name} {spark_type}{not_null}")
return ", ".join(fields)

@classmethod
def _field_type(cls, field: dataclasses.Field):
# workaround rare (Python?) issue where f.type is the type name instead of the type itself
# this seems to happen when the dataclass is first used from a file importing it
if isinstance(field.type, str):
try:
return __builtins__[field.type]
except TypeError as e:
logger.warning(f"Could not load type {field.type}", exc_info=e)
return field.type

@classmethod
def _filter_none_rows(cls, rows, klass):
if len(rows) == 0:
Expand Down Expand Up @@ -168,12 +172,12 @@ def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: D
sql = f'INSERT INTO {full_name} ({", ".join(field_names)}) VALUES ({vals})'
self.execute(sql)

@staticmethod
def _row_to_sql(row: DataclassInstance, fields: tuple[dataclasses.Field[Any], ...]):
@classmethod
def _row_to_sql(cls, row: DataclassInstance, fields: tuple[dataclasses.Field[Any], ...]):
data = []
for f in fields:
value = getattr(row, f.name)
field_type = f.type
field_type = cls._field_type(f)
if isinstance(field_type, UnionType):
field_type = field_type.__args__[0]
if value is None:
Expand Down

0 comments on commit 21995fd

Please sign in to comment.