Skip to content

Commit

Permalink
add pyarrow support in find_column_type for pandas dataframes (#313)
Browse files Browse the repository at this point in the history
* add pyarrow support in find_column_type for pandas dataframes

* update pandas lower pin

* change default to varchar
  • Loading branch information
gladysteh99 authored Nov 19, 2024
1 parent ee7f51b commit 6b0f1f8
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 18 deletions.
53 changes: 36 additions & 17 deletions locopy/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import pandas as pd
import polars as pl
import pyarrow as pa
import yaml

from locopy.errors import (
Expand Down Expand Up @@ -317,6 +318,20 @@ def validate_float_object(column):
except (ValueError, TypeError):
return None

def check_column_type_pyarrow(pa_dtype):
    """Map a pyarrow ``DataType`` to a warehouse column type string.

    Parameters
    ----------
    pa_dtype : pyarrow.DataType
        The pyarrow type backing a pandas ``ArrowDtype`` column.

    Returns
    -------
    str
        One of ``"timestamp"``, ``"boolean"``, ``"int"``, ``"float"``
        or ``"varchar"``. Any type not matched by the checks below
        (strings, decimals, binary, nested types, ...) falls back to
        ``"varchar"``.
    """
    if pa.types.is_temporal(pa_dtype):
        # Covers timestamp, date, time and duration types.
        return "timestamp"
    elif pa.types.is_boolean(pa_dtype):
        return "boolean"
    elif pa.types.is_integer(pa_dtype):
        return "int"
    elif pa.types.is_floating(pa_dtype):
        return "float"
    else:
        # Default for everything else (including strings); the original
        # had a separate is_string branch that returned the same value,
        # which was redundant with this fallback.
        return "varchar"

if warehouse_type.lower() not in ["snowflake", "redshift"]:
raise ValueError(
'warehouse_type argument must be either "snowflake" or "redshift"'
Expand All @@ -328,24 +343,28 @@ def validate_float_object(column):
data = dataframe[column].dropna().reset_index(drop=True)
if data.size == 0:
column_type.append("varchar")
elif (data.dtype in ["datetime64[ns]", "M8[ns]"]) or (
re.match(r"(datetime64\[ns\,\W)([a-zA-Z]+)(\])", str(data.dtype))
):
column_type.append("timestamp")
elif str(data.dtype).lower().startswith("bool"):
column_type.append("boolean")
elif str(data.dtype).startswith("object"):
data_type = validate_float_object(data) or validate_date_object(data)
if not data_type:
column_type.append("varchar")
else:
column_type.append(data_type)
elif str(data.dtype).lower().startswith("int"):
column_type.append("int")
elif str(data.dtype).lower().startswith("float"):
column_type.append("float")
elif isinstance(data.dtype, pd.ArrowDtype):
datatype = check_column_type_pyarrow(data.dtype.pyarrow_dtype)
column_type.append(datatype)
else:
column_type.append("varchar")
if (data.dtype in ["datetime64[ns]", "M8[ns]"]) or (
re.match(r"(datetime64\[ns\,\W)([a-zA-Z]+)(\])", str(data.dtype))
):
column_type.append("timestamp")
elif str(data.dtype).lower().startswith("bool"):
column_type.append("boolean")
elif str(data.dtype).startswith("object"):
data_type = validate_float_object(data) or validate_date_object(data)
if not data_type:
column_type.append("varchar")
else:
column_type.append(data_type)
elif str(data.dtype).lower().startswith("int"):
column_type.append("int")
elif str(data.dtype).lower().startswith("float"):
column_type.append("float")
else:
column_type.append("varchar")
logger.info("Parsing column %s to %s", column, column_type[-1])
return OrderedDict(zip(list(dataframe.columns), column_type))

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ authors = [
{ name="Faisal Dosani", email="faisal.dosani@capitalone.com" },
]
license = {text = "Apache Software License"}
dependencies = ["boto3<=1.35.53,>=1.9.92", "PyYAML<=6.0.1,>=5.1", "pandas<=2.2.3,>=0.25.2", "numpy<=2.0.2,>=1.22.0", "polars>=0.20.0"]
dependencies = ["boto3<=1.35.53,>=1.9.92", "PyYAML<=6.0.1,>=5.1", "pandas<=2.2.3,>=1.5.0", "numpy<=2.0.2,>=1.22.0", "polars>=0.20.0", "pyarrow>=10.0.1"]

requires-python = ">=3.9.0"
classifiers = [
Expand Down
42 changes: 42 additions & 0 deletions tests/test_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from unittest import mock

import locopy.utility as util
import pyarrow as pa
import pytest
from locopy.errors import (
CompressionError,
Expand Down Expand Up @@ -388,7 +389,48 @@ def test_find_column_type_new():
"d": "varchar",
"e": "boolean",
}
assert find_column_type(input_text, "snowflake") == output_text_snowflake
assert find_column_type(input_text, "redshift") == output_text_redshift


def test_find_column_type_pyarrow():
    """find_column_type resolves pyarrow-backed pandas dtypes for both warehouses."""
    import pandas as pd

    frame = pd.DataFrame(
        {
            "a": [1],
            "b": [pd.Timestamp("2017-01-01T12+0")],
            "c": [1.2],
            "d": ["a"],
            "e": [True],
        }
    ).astype(
        {
            "a": "int64[pyarrow]",
            "b": "timestamp[ns, tz=UTC][pyarrow]",
            "c": "float64[pyarrow]",
            "d": pd.ArrowDtype(pa.string()),
            "e": "bool[pyarrow]",
        }
    )

    # Both warehouses map pyarrow-backed columns to the same type names.
    expected = {
        "a": "int",
        "b": "timestamp",
        "c": "float",
        "d": "varchar",
        "e": "boolean",
    }

    assert find_column_type(frame, "snowflake") == expected
    assert find_column_type(frame, "redshift") == expected

Expand Down

0 comments on commit 6b0f1f8

Please sign in to comment.