diff --git a/.gitignore b/.gitignore index 1dcc7e4..ea34492 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ venv/ .vscode/ dev/ +service_accounts/ build/ dist/ diff --git a/README.md b/README.md index 3fd48fd..ebc7199 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ Since Polars leverages Rust speedups, you need to have Rust installed in your en ## Usage -In this demo we'll connect to BigQuery, read data, transform it, and write it back to the data warehouse. +In this demo we'll connect to Google BigQuery, read data, transform it, and write it back to the data warehouse. First, connect to the BigQuery warehouse by supplying the `BigQueryConnector()` object with the relative path to your service account credentials. @@ -32,7 +32,7 @@ bq = BigQueryConnector( ) ``` -Next, supply the object with a SQL query in the `read_dataframe_from_bigquery()` function to redner a `DataFrame` object: +Next, supply the object with a SQL query in the `read_dataframe()` function to redner a `DataFrame` object: ``` # Write some valid SQL @@ -45,7 +45,7 @@ ORDER BY avg_points DESC # Pull BigQuery data into a Polars DataFrame -nyk = bq.read_dataframe_from_bigquery(sql=sql) +nyk = bq.read_dataframe(sql=sql) ``` Now that your data is pulled into a local instance, you can clean and transform it using standard Polars functionality - [see the docs](https://docs.pola.rs/py-polars/html/reference/dataframe/index.html) for more information. @@ -61,11 +61,11 @@ key_metrics = [ summary_stats = nyk[key_metrics].describe() ``` -Finally, push your transformed data back to the BigQuery warehouse using the `write_dataframe_to_bigquery()` function: +Finally, push your transformed data back to the BigQuery warehouse using the `write_dataframe()` function: ``` # Write back to BigQuery -bq.write_dataframe_to_bigquery( +bq.write_dataframe( df=summary_stats, table_name="nba_dbt.summary_statistics", if_exists="truncate" diff --git a/klondike/__init__.py b/klondike/__init__.py index 7e39c33..5ef879f 100644 --- a/klondike/__init__.py +++ b/klondike/__init__.py @@ -25,6 +25,7 @@ __all__ = [] POLAR_OBJECTS = [ ("klondike.bigquery.bigquery", "BigQueryConnector"), + ("klondike.snowflake.snowflake", "SnowflakeConnector"), ] for module_, object_ in POLAR_OBJECTS: diff --git a/klondike/bigquery/bigquery.py b/klondike/bigquery/bigquery.py index 30e70c6..c4cec3f 100644 --- a/klondike/bigquery/bigquery.py +++ b/klondike/bigquery/bigquery.py @@ -7,8 +7,10 @@ import polars as pl from google.cloud import bigquery from google.cloud.bigquery import LoadJobConfig +from google.cloud.exceptions import NotFound from klondike import logger +from klondike.utilities.utilities import validate_if_exists_behavior ########## @@ -26,12 +28,18 @@ class BigQueryConnector: Establish and authenticate a connection to a BigQuery warehouse Args: - app_creds: Google service account, either as a relative path or a dictionary instance - project: Name of Google Project - location: Location of Google Project - timeout: Temporal threshold to kill a stalled job, defaults to 60s - client_options: API scopes - google_environment_variable: Provided for flexibility, defaults to `GOOGLE_APPLICATION_CREDENTIALS` + app_creds: `str` + Google service account, either as a relative path or a dictionary instance + project: `str` + Name of Google Project + location: `str` + Location of Google Project + timeout: `int` + Temporal threshold to kill a stalled job, defaults to 60s + client_options: `list` + API scopes + google_environment_variable: `str` + Provided for 
flexibility, defaults to `GOOGLE_APPLICATION_CREDENTIALS` """ def __init__( @@ -46,11 +54,11 @@ def __init__( self.app_creds = app_creds self.project = project self.location = location - self.timeout = timeout self.client_options = client_options - self._client = None self.dialect = "bigquery" + self.__client = None + self.__timeout = timeout if not self.app_creds: if not os.environ.get(google_environment_variable): @@ -64,6 +72,30 @@ def __init__( app_creds=self.app_creds, env_variable=google_environment_variable ) + @property + def client(self): + """ + Instantiate BigQuery client and assign it + as class property + """ + + if not self.__client: + self.__client = bigquery.Client( + project=self.project, + location=self.location, + client_options=self.client_options, + ) + + return self.__client + + @property + def timeout(self): + return self.__timeout + + @timeout.setter + def timeout(self, timeout): + self.__timeout = timeout + def __setup_google_app_creds(self, app_creds: Union[str, dict], env_variable: str): "Sets runtime environment variable for Google SDK" @@ -89,6 +121,20 @@ def __set_load_job_config( ): "Defines `LoadConfigJob` when writing to BigQuery" + def set_write_disposition(if_exists: str): + DISPOSITION_MAP = { + "fail": bigquery.WriteDisposition.WRITE_EMPTY, + "append": bigquery.WriteDisposition.WRITE_APPEND, + "truncate": bigquery.WriteDisposition.WRITE_TRUNCATE, + } + + return DISPOSITION_MAP[if_exists] + + def set_table_schema(table_schema: list): + return [bigquery.SchemaField(**x) for x in table_schema] + + ### + if not base_job_config: logger.debug("No job config provided, starting fresh") base_job_config = LoadJobConfig() @@ -98,17 +144,16 @@ def __set_load_job_config( # Create table schema mapping if provided if table_schema: - base_job_config.schema = self.__set_table_schema(table_schema=table_schema) + base_job_config.schema = set_table_schema(table_schema=table_schema) else: base_job_config.schema = None base_job_config.max_bad_records = max_bad_records + base_job_config.write_disposition = set_write_disposition(if_exists=if_exists) - base_job_config.write_disposition = self.__set_write_disposition( - if_exists=if_exists - ) + ### - # List of LoadJobConfig attributes + # List of available LoadJobConfig attributes _attributes = [x for x in dict(vars(LoadJobConfig)).keys()] # Attributes that will not be overwritten @@ -127,44 +172,14 @@ def __set_load_job_config( return base_job_config - def __set_table_schema(self, table_schema: list): - "TODO - Write about me" - - return [bigquery.SchemaField(**x) for x in table_schema] - - def __set_write_disposition(self, if_exists: str): - "TODO - Write about me" - - DISPOSITION_MAP = { - "fail": bigquery.WriteDisposition.WRITE_EMPTY, - "append": bigquery.WriteDisposition.WRITE_APPEND, - "truncate": bigquery.WriteDisposition.WRITE_TRUNCATE, - } - - return DISPOSITION_MAP[if_exists] - - @property - def client(self): - """ - Instantiate BigQuery client - """ - - if not self._client: - self._client = bigquery.Client( - project=self.project, - location=self.location, - client_options=self.client_options, - ) - - return self._client - - def read_dataframe_from_bigquery(self, sql: str) -> pl.DataFrame: + def read_dataframe(self, sql: str) -> pl.DataFrame: """ Executes a SQL query and returns a Polars DataFrame. 
TODO - Make this more flexible and incorporate query params Args: - sql: String representation of SQL query + sql: `str` + String representation of SQL query Returns: Polars DataFrame object @@ -185,9 +200,10 @@ def read_dataframe_from_bigquery(self, sql: str) -> pl.DataFrame: return logger.info(f"Successfully read {len(df)} rows from BigQuery") + return df - def write_dataframe_to_bigquery( + def write_dataframe( self, df: pl.DataFrame, table_name: str, @@ -201,16 +217,26 @@ def write_dataframe_to_bigquery( Writes a Polars DataFrame to BigQuery Args: - df: Polars DataFrame - table_name: Destination table name to write to - `dataset.table` convention - load_job_config: `LoadJobConfig` object. If none is supplied, several defaults are applied - max_bad_records: Tolerance for bad records in the load job, defaults to 0 - table_schema: List of column names, types, and optional flags to include - if_exists: One of `fail`, `drop`, `append`, `truncate` - load_kwargs: See here for list of accepted values - https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.job.LoadJobConfig + df: `polars.DataFrame` + DataFrame to write to BigQuery + table_name: `str` + Destination table name to write to - `dataset.table` convention + load_job_config: `LoadJobConfig` + Configures load job; if none is supplied, several defaults are applied + max_bad_records: `int` + Tolerance for bad records in the load job, defaults to 0 + table_schema: `list` + List of column names, types, and optional flags to include + if_exists: `str` + One of `fail`, `drop`, `append`, `truncate` + load_kwargs: + See here for list of accepted values \ + https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.job.LoadJobConfig """ + if not validate_if_exists_behavior(user_input=if_exists): + raise ValueError(f"{if_exists} is an invalid input") + if if_exists == "drop": self.client.delete_table(table=table_name) if_exists = "fail" @@ -233,3 +259,36 @@ def write_dataframe_to_bigquery( load_job.result() logger.info(f"Successfuly wrote {len(df)} rows to {table_name}") + + def table_exists(self, table_name: str) -> bool: + """ + Determines if a BigQuery table exists + + Args: + table_name: `str` + BigQuery table name in `schema.table` or `project.schema.table` format + """ + + try: + _ = self.client.get_table(table=table_name) + return True + + except NotFound: + return False + + def list_tables(self, schema_name: str) -> list: + """ + Gets a list of available tables in a BigQuery schema + + Args: + schema_name: `str` + BigQuery schema name + + Returns: + List of table names + """ + + return [ + x.full_table_id.replace(":", ".") + for x in self.client.list_tables(dataset=schema_name) + ] diff --git a/klondike/utils/__init__.py b/klondike/snowflake/__init__.py similarity index 100% rename from klondike/utils/__init__.py rename to klondike/snowflake/__init__.py diff --git a/klondike/snowflake/snowflake.py b/klondike/snowflake/snowflake.py new file mode 100644 index 0000000..2e6eec3 --- /dev/null +++ b/klondike/snowflake/snowflake.py @@ -0,0 +1,349 @@ +import os +from contextlib import contextmanager +from typing import Optional + +import polars as pl +import snowflake.connector as snow +from snowflake.connector.pandas_tools import write_pandas + +from klondike import logger +from klondike.utilities.utilities import validate_if_exists_behavior + +########## + + +class SnowflakeConnector: + """ + Leverages connection to Snowflake to read and write Polars DataFrame + objects to the data 
warehouse + + `Args`: + snowflake_user: `str` + Username to connect to Snowflake (defaults to `SNOWFLAKE_USER` in environment) + snowflake_password: `str` + Password to connect to Snowflake (defaults to `SNOWFLAKE_PASSWORD` in environment) + snowflake_account: `str` + Account identifier for Snowflake warehouse (defaults to `SNOWFLAKE_ACCOUNT` in environment) + snowflake_warehouse: `str` + Snowflake warehouse name (defaults to `SNOWFLAKE_WAREHOUSE` in environment) + snowflake_database: `str` + Snowflake database name (defaults to `SNOWFLAKE_DATABASE` in environment) + row_chunk_size: `int` + Default row chunk size for reading from / writing to Snowflake + """ + + def __init__( + self, + snowflake_user: Optional[str] = None, + snowflake_password: Optional[str] = None, + snowflake_account: Optional[str] = None, + snowflake_database: Optional[str] = None, + snowflake_warehouse: Optional[str] = None, + row_chunk_size: int = 100_000, + ): + """ + All Snowflake connection values either need to be supplied as constructor + arguments or inferred from the environment; if neither occurs, a `ValueError` + will be raised + """ + + self.snowflake_user = ( + snowflake_user if snowflake_user else os.getenv("SNOWFLAKE_USER") + ) + self.snowflake_password = ( + snowflake_password + if snowflake_password + else os.getenv("SNOWFLAKE_PASSWORD") + ) + self.snowflake_account = ( + snowflake_account if snowflake_account else os.getenv("SNOWFLAKE_ACCOUNT") + ) + self.__snowflake_warehouse = ( + snowflake_warehouse + if snowflake_warehouse + else os.getenv("SNOWFLAKE_WAREHOUSE") + ) + self.__snowflake_database = ( + snowflake_database + if snowflake_database + else os.getenv("SNOWFLAKE_DATABASE") + ) + + ### + + self.__validate_authentication() + + ### + + self.dialect = "snowflake" + self.__row_chunk_size = row_chunk_size + + def __validate_authentication(self): + _auth_vals = [ + self.snowflake_user, + self.snowflake_password, + self.snowflake_account, + self.snowflake_database, + self.snowflake_warehouse, + ] + + if any([not x for x in _auth_vals]): + raise ValueError( + "Missing authentication values! Make sure all `snowflake_*` values are provided at construction" + ) + + @property + def snowflake_warehouse(self): + return self.__snowflake_warehouse + + @snowflake_warehouse.setter + def snowflake_warehouse(self, warehouse): + self.__snowflake_warehouse = warehouse + + @property + def snowflake_database(self): + return self.__snowflake_database + + @snowflake_database.setter + def snowflake_database(self, database): + self.__snowflake_database = database + + @property + def row_chunk_size(self): + return self.__row_chunk_size + + @row_chunk_size.setter + def row_chunk_size(self, row_chunk_size): + self.__row_chunk_size = row_chunk_size + + @contextmanager + def connection(self): + """ + Creates a connection to Snowflake + """ + + conn = snow.connect( + account=self.snowflake_account, + warehouse=self.snowflake_warehouse, + user=self.snowflake_user, + password=self.snowflake_password, + ) + + try: + yield conn + conn.commit() + finally: + conn.close() + + @contextmanager + def cursor(self, connection): + """ + Leverages Snowflake connection to execute SQL transactions + """ + + cur = connection.cursor() + + try: + yield cur + finally: + cur.close() + + def __query(self, sql: str): + """ + Executes SQL command against Snowflake warehouse + + Args: + sql: `str` + SQL query in string format + """ + + with self.connection() as _conn: + with self.cursor(_conn) as _cursor: + # TODO - Ugly! 
Clean this up + _cursor.execute(f"USE DATABASE {self.snowflake_database};") + _cursor.execute(f"USE WAREHOUSE {self.snowflake_warehouse};") + _cursor.execute(sql) + _resp = _cursor.fetch_arrow_batches() + + try: + return pl.from_arrow(_resp) + except ValueError as ve: + # NOTE - This appears to be Polars response to empty fetch_arrow_batches() + # This should be interrogated more, but is functional + if "Must pass schema, or at least one RecordBatch" in str(ve): + logger.debug("No results obtained via query") + return pl.DataFrame() + raise + + def read_dataframe(self, sql: str) -> pl.DataFrame: + """ + Executes a SQL query against the Snowflake warehouse + and returns the result as a Polars DataFrame object + + Args: + sql: `str` + String representation of SQL query + + Returns: + Polars DataFrame object + """ + + # Execute SQL against warehouse + logger.debug(f"Running SQL: {sql}") + df = self.__query(sql=sql) + + logger.info(f"Successfully read {len(df)} rows from Snowflake") + + return df + + def write_dataframe( + self, + df: pl.DataFrame, + table_name: str, + schema_name: str, + database_name: Optional[str] = None, + if_exists: str = "append", + auto_create_table: bool = True, + chunk_output: bool = False, + ) -> None: + """ + Writes a Polars DataFrame to Snowflake + + TODO - Utilizing `.to_pandas()` may not be ideal; we should consider + modifying this function to utilize `PyArrow` under the hood (see - `klondike.BigQueryConnector`) + + Args: + df: `pl.DataFrame` + Polars DataFrame to be written to Snowflake + table_name: `str` + Destination table name + schema_name: `str` + Destination schema name + database_name: `str` + Optional destination database name (defaults to the preset class attribute) + if_exists: `str` + One of `append`, `truncate`, `drop`, `fail` + auto_create_table: `bool` + If True, the destination table will be created if it doesn't already exist + chunk_output: `bool` + If True, the default chunk size will be applied; otherwise no chunking will occur + """ + + # Use class attribute for row chunking + row_chunks = self.row_chunk_size if chunk_output else None + + # Confirm that user has passed in valid logic + if not validate_if_exists_behavior(if_exists): + raise ValueError(f"{if_exists} is an invalid input") + + overwrite = False + + # This logic leverages the behavior of the write_pandas() function ... see below + if if_exists == "drop" and auto_create_table: + overwrite = True + logger.warning(f"{table_name} will be dropped if it exists") + + elif if_exists == "truncate": + overwrite = True + auto_create_table = False + + elif if_exists == "append": + overwrite = False + auto_create_table = False + + elif if_exists == "fail": + if self.table_exists(table_name=f"{schema_name}.{table_name}", database_name=database_name): + raise snow.errors.DatabaseError(f"{table_name} already exists") + + database = database_name if database_name else self.snowflake_database + + ### + + logger.info( + f"Writing to {database}.{schema_name}.{table_name}..." 
+ ) + with self.connection() as conn: + resp, num_chunks, num_rows, output = write_pandas( + conn=conn, + df=df.to_pandas(), + database=database, + schema=schema_name, + table_name=table_name, + auto_create_table=auto_create_table, + chunk_size=row_chunks, + overwrite=overwrite, + ) + + ### + + if resp: + logger.info(f"Successfully wrote {num_rows} rows to {table_name}") + else: + logger.error(f"Failed to write to {table_name}: {output}") + raise snow.errors.DatabaseError(f"Failed to write to {table_name}") + + def table_exists( + self, table_name: str, database_name: Optional[str] = None + ) -> bool: + """ + Determines if a Snowflake table exists in the warehouse + + Args: + table_name: `str` + Snowflake table name in `schema.table` format + database_name: `str` + Optional database name (defaults to the preset class attribute) + + Returns: + True if the table exists, False otherwise + """ + + schema, table = table_name.split(".") + + if not database_name: + logger.debug(f"Defaulting to default database [{self.snowflake_database}]") + database_name = self.snowflake_database + + sql = f""" + SELECT + * + FROM INFORMATION_SCHEMA.TABLES + WHERE table_schema = '{schema}' + AND table_name = '{table}' + AND table_catalog = '{database_name}' + """ + + resp = self.__query(sql=sql) + + return not resp.is_empty() + + def list_tables( + self, schema_name: str, database_name: Optional[str] = None + ) -> list: + """ + Gets a list of available tables in a Snowflake schema + + Args: + schema_name: `str` + Snowflake schema name + database_name: `str` + Optional database name (defaults to the preset class attribute) + + Returns: + List of table names + """ + + if not database_name: + database_name = self.snowflake_database + + sql = f""" + SELECT + table_catalog, + table_schema, + table_name + FROM INFORMATION_SCHEMA.TABLES + WHERE table_schema = '{schema_name}' + AND table_catalog = '{database_name}' + """ + + resp = self.__query(sql=sql) + + resp = resp.select(table_name=pl.concat_str(resp.columns, separator=".")) + + return pl.Series(resp["table_name"]).to_list() diff --git a/klondike/utilities/__init__.py b/klondike/utilities/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/klondike/utilities/utilities.py b/klondike/utilities/utilities.py new file mode 100644 index 0000000..8a22cf6 --- /dev/null +++ b/klondike/utilities/utilities.py @@ -0,0 +1,13 @@ +""" +General utilities to use across all database connections +""" + + +def validate_if_exists_behavior( + user_input: str, acceptable_values: list = ["fail", "append", "truncate", "drop"] +): + """ + Ensures that user input to `write_dataframe` function is valid + """ + + return user_input in acceptable_values diff --git a/requirements.txt b/requirements.txt index f73718d..d73457e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ google-cloud-bigquery==3.19.0 +snowflake-connector-python==3.10.1 polars==0.20.16 pyarrow==15.0.2 \ No newline at end of file diff --git a/setup.py b/setup.py index c2fd8cd..a6ba526 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,8 @@ -import setuptools import os import pathlib +import setuptools + ########## @@ -14,11 +15,11 @@ def main(): setuptools.setup( name="klondike", - version="0.1.0", + version="0.2.0", author="Ian Richard Ferguson", author_email="IRF229@nyu.edu", url="https://github.com/IanRFerguson/klondike", - keywords=["API", "ETL", "BIGQUERY"], + keywords=["API", "ETL", "BIGQUERY", "SNOWFLAKE"], packages=setuptools.find_packages(), install_requires=INSTALL_REQUIRES, classifiers=[ diff --git a/tests/test_bigquery_connector.py b/tests/test_bigquery_connector.py deleted file mode 100644 index 23d5789..0000000 --- a/tests/test_bigquery_connector.py +++ /dev/null @@ -1,69 +0,0 @@ -import os -from unittest import mock - 
-import polars as pl - -from klondike import BigQueryConnector - -from .test_utils import KlondikeTestCase - -########## - - -class TestBigQuery(KlondikeTestCase): - def setUp(self): - super().setUp() - os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self._credentials_path - - def tearDown(self): - super().tearDown() - del os.environ["GOOGLE_APPLICATION_CREDENTIALS"] - - def _build_mock_cursor(self, query_results=None): - cursor = mock.MagicMock() - cursor.execute.return_value = None - cursor.fetchmany.side_effect = [query_results, []] - - if query_results is not None: - cursor.description = query_results - - # Create a mock that will play the role of the connection - connection = mock.MagicMock() - connection.cursor.return_value = cursor - - # Create a mock that will play the role of our GoogleBigQuery client - client = mock.MagicMock() - - bq = BigQueryConnector() - bq.connection = connection - bq._client = client - - return bq - - @mock.patch("polars.from_arrow") - def test_read_dataframe_from_bigquery(self, mock_from_arrow): - "Tests read functionality for the `BigQueryConnector` object" - - sql = "select * from my_table" - tbl = pl.DataFrame( - [ - {"city": "Brooklyn", "state": "New York"}, - {"city": "San Francisco", "state": "California"}, - {"city": "Richmond", "state": "Virginia"}, - ] - ) - - bq = self._build_mock_cursor(query_results=tbl) - df = bq.read_dataframe_from_bigquery(sql=sql) - - assert isinstance(df, type(None)) - - @mock.patch("polars.DataFrame.write_parquet") - def test_write_dataframe_to_bigquery(self, mock_write_parquet): - "Tests write functionality for the `BigQueryConnector` object" - - df = mock.MagicMock() - table_name = "foo.bar" - - bq = self._build_mock_cursor() - bq.write_dataframe_to_bigquery(df=df, table_name=table_name) diff --git a/tests/test_utils.py b/tests/test_utils.py deleted file mode 100644 index d16381b..0000000 --- a/tests/test_utils.py +++ /dev/null @@ -1,35 +0,0 @@ -import json -import os -import tempfile -import unittest - -########## - - -class KlondikeTestCase(unittest.TestCase): - def setUp(self): - self._temp_directory = tempfile.TemporaryDirectory() - - self._credentials_path = os.path.join( - self._temp_directory.name, "service_account.json" - ) - - self._service_account = { - "type": "foo", - "project_id": "bar", - "private_key_id": "biz", - "private_key": "bap", - "client_email": "bim", - "client_id": "top", - "auth_uri": "hat", - "token_uri": "tap", - "auth_provider_x509_cert_url": "dance", - "client_x509_cert_url": "good", - "universe_domain": "stuff", - } - - with open(self._credentials_path, "w") as f: - json.dump(self._service_account, f) - - def tearDown(self): - self._temp_directory.cleanup()
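Illustrative usage sketch (not part of the diff above). The snippet exercises the API this change settles on: the renamed `read_dataframe()` / `write_dataframe()` methods, the new `table_exists()` and `list_tables()` helpers, and the new `SnowflakeConnector`. The credential path, Google project, and Snowflake schema/table names are placeholders, and the Snowflake connector is assumed to find the `SNOWFLAKE_*` environment variables described in its docstring.

```
from klondike import BigQueryConnector, SnowflakeConnector

# BigQuery - placeholder credentials and project
bq = BigQueryConnector(
    app_creds="service_accounts/demo_project.json",
    project="demo-project",
)

if bq.table_exists("nba_dbt.cln_player__averages"):
    nyk = bq.read_dataframe(sql="SELECT * FROM nba_dbt.cln_player__averages")
    bq.write_dataframe(
        df=nyk.describe(),
        table_name="nba_dbt.summary_statistics",
        if_exists="truncate",
    )
print(bq.list_tables(schema_name="nba_dbt"))

# Snowflake - constructor falls back to the SNOWFLAKE_* environment variables
sf = SnowflakeConnector()

teams = sf.read_dataframe(sql="SELECT * FROM NBA.TEAMS")  # placeholder table
sf.write_dataframe(
    df=teams,
    table_name="TEAMS_COPY",
    schema_name="NBA",
    if_exists="truncate",
)
print(sf.list_tables(schema_name="NBA"))
```

Both connectors route `if_exists` through the shared `validate_if_exists_behavior()` utility, so the accepted values (`fail`, `append`, `truncate`, `drop`) are identical across warehouses.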