diff --git a/locopy/utility.py b/locopy/utility.py index 7b61f31..22cd6c0 100644 --- a/locopy/utility.py +++ b/locopy/utility.py @@ -30,6 +30,7 @@ import pandas as pd import polars as pl +import pyarrow as pa import yaml from locopy.errors import ( @@ -317,6 +318,20 @@ def validate_float_object(column): except (ValueError, TypeError): return None + def check_column_type_pyarrow(pa_dtype): + if pa.types.is_temporal(pa_dtype): + return "timestamp" + elif pa.types.is_boolean(pa_dtype): + return "boolean" + elif pa.types.is_integer(pa_dtype): + return "int" + elif pa.types.is_floating(pa_dtype): + return "float" + elif pa.types.is_string(pa_dtype): + return "varchar" + else: + return "varchar" + if warehouse_type.lower() not in ["snowflake", "redshift"]: raise ValueError( 'warehouse_type argument must be either "snowflake" or "redshift"' @@ -328,24 +343,28 @@ def validate_float_object(column): data = dataframe[column].dropna().reset_index(drop=True) if data.size == 0: column_type.append("varchar") - elif (data.dtype in ["datetime64[ns]", "M8[ns]"]) or ( - re.match(r"(datetime64\[ns\,\W)([a-zA-Z]+)(\])", str(data.dtype)) - ): - column_type.append("timestamp") - elif str(data.dtype).lower().startswith("bool"): - column_type.append("boolean") - elif str(data.dtype).startswith("object"): - data_type = validate_float_object(data) or validate_date_object(data) - if not data_type: - column_type.append("varchar") - else: - column_type.append(data_type) - elif str(data.dtype).lower().startswith("int"): - column_type.append("int") - elif str(data.dtype).lower().startswith("float"): - column_type.append("float") + elif isinstance(data.dtype, pd.ArrowDtype): + datatype = check_column_type_pyarrow(data.dtype.pyarrow_dtype) + column_type.append(datatype) else: - column_type.append("varchar") + if (data.dtype in ["datetime64[ns]", "M8[ns]"]) or ( + re.match(r"(datetime64\[ns\,\W)([a-zA-Z]+)(\])", str(data.dtype)) + ): + column_type.append("timestamp") + elif str(data.dtype).lower().startswith("bool"): + column_type.append("boolean") + elif str(data.dtype).startswith("object"): + data_type = validate_float_object(data) or validate_date_object(data) + if not data_type: + column_type.append("varchar") + else: + column_type.append(data_type) + elif str(data.dtype).lower().startswith("int"): + column_type.append("int") + elif str(data.dtype).lower().startswith("float"): + column_type.append("float") + else: + column_type.append("varchar") logger.info("Parsing column %s to %s", column, column_type[-1]) return OrderedDict(zip(list(dataframe.columns), column_type)) diff --git a/pyproject.toml b/pyproject.toml index 91bf317..43d6bf7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ authors = [ { name="Faisal Dosani", email="faisal.dosani@capitalone.com" }, ] license = {text = "Apache Software License"} -dependencies = ["boto3<=1.35.53,>=1.9.92", "PyYAML<=6.0.1,>=5.1", "pandas<=2.2.3,>=0.25.2", "numpy<=2.0.2,>=1.22.0", "polars>=0.20.0"] +dependencies = ["boto3<=1.35.53,>=1.9.92", "PyYAML<=6.0.1,>=5.1", "pandas<=2.2.3,>=1.5.0", "numpy<=2.0.2,>=1.22.0", "polars>=0.20.0", "pyarrow>=10.0.1"] requires-python = ">=3.9.0" classifiers = [ diff --git a/tests/test_utility.py b/tests/test_utility.py index ec6f8f2..a7991f8 100644 --- a/tests/test_utility.py +++ b/tests/test_utility.py @@ -29,6 +29,7 @@ from unittest import mock import locopy.utility as util +import pyarrow as pa import pytest from locopy.errors import ( CompressionError, @@ -388,7 +389,48 @@ def test_find_column_type_new(): "d": "varchar", "e": "boolean", } + assert find_column_type(input_text, "snowflake") == output_text_snowflake + assert find_column_type(input_text, "redshift") == output_text_redshift + + +def test_find_column_type_pyarrow(): + import pandas as pd + + input_text = pd.DataFrame.from_dict( + { + "a": [1], + "b": [pd.Timestamp("2017-01-01T12+0")], + "c": [1.2], + "d": ["a"], + "e": [True], + } + ) + input_text = input_text.astype( + dtype={ + "a": "int64[pyarrow]", + "b": "timestamp[ns, tz=UTC][pyarrow]", + "c": "float64[pyarrow]", + "d": pd.ArrowDtype(pa.string()), + "e": "bool[pyarrow]", + } + ) + + output_text_snowflake = { + "a": "int", + "b": "timestamp", + "c": "float", + "d": "varchar", + "e": "boolean", + } + + output_text_redshift = { + "a": "int", + "b": "timestamp", + "c": "float", + "d": "varchar", + "e": "boolean", + } assert find_column_type(input_text, "snowflake") == output_text_snowflake assert find_column_type(input_text, "redshift") == output_text_redshift