Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow snowflake cursor on write_pandas. #1958

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/snowflake/connector/pandas_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def _create_temp_file_format(


def write_pandas(
conn: SnowflakeConnection,
conn: SnowflakeConnection | SnowflakeCursor,
df: pandas.DataFrame,
table_name: str,
database: str | None = None,
Expand Down Expand Up @@ -216,7 +216,7 @@ def write_pandas(
success, nchunks, nrows, _ = write_pandas(cnx, df, 'customers')

Args:
conn: Connection to be used to communicate with Snowflake.
conn: Connection or Cursor to be used to communicate with Snowflake.
df: Dataframe we'd like to write back.
table_name: Table name where we want to insert into.
database: Database schema and table is in, if not provided the default one will be used (Default value = None).
Expand Down Expand Up @@ -315,7 +315,8 @@ def write_pandas(
else:
sql_use_logical_type = " USE_LOGICAL_TYPE = FALSE"

cursor = conn.cursor()
if isinstance(conn, SnowflakeConnection):
cursor = conn.cursor()
stage_location = _create_temp_stage(
cursor,
database,
Expand Down
14 changes: 14 additions & 0 deletions test/integ/pandas/test_pandas_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,20 @@ def test_empty_dataframe_write_pandas(
), f"sucess: {success}, num_chunks: {num_chunks}, num_rows: {num_rows}"


def test_cursor_write_pandas(
conn_cnx: Callable[..., Generator[SnowflakeConnection, None, None]],
):
table_name = random_string(5, "empty_dataframe_")
df = pandas.DataFrame([], columns=["name", "balance"])
with conn_cnx() as cnx:
success, num_chunks, num_rows, _ = write_pandas(
cnx.cursor(), df, table_name, auto_create_table=True, table_type="temp"
)
assert (
success and num_chunks == 1 and num_rows == 0
), f"sucess: {success}, num_chunks: {num_chunks}, num_rows: {num_rows}"


@pytest.mark.parametrize(
"database,schema,quote_identifiers,expected_location",
[
Expand Down
Loading