Update SparkDataSet's doc and add support for fsspec (#128)
* update notes

Signed-off-by: Nok Chan <nok.lam.chan@quantumblack.com>

* Simplify logic

Signed-off-by: Nok Chan <nok.lam.chan@quantumblack.com>

* Revert "Simplify logic"

This reverts commit 734902a.

Signed-off-by: Nok Chan <nok.lam.chan@quantumblack.com>

* Refactoring

Signed-off-by: Nok Chan <nok.lam.chan@quantumblack.com>

* more refactoring

Signed-off-by: Nok Chan <nok.lam.chan@quantumblack.com>

* refactor

Signed-off-by: Nok Chan <nok.lam.chan@quantumblack.com>

* update datasets

Signed-off-by: Nok Chan <nok.lam.chan@quantumblack.com>

* fix prefix

Signed-off-by: Nok Chan <nok.lam.chan@quantumblack.com>

* Fix bug and tests

Signed-off-by: Nok Chan <nok.lam.chan@quantumblack.com>

* update release notes

Signed-off-by: Nok Chan <nok.lam.chan@quantumblack.com>

* Apply comments

---------

Signed-off-by: Nok Chan <nok.lam.chan@quantumblack.com>
noklam authored Mar 24, 2023
1 parent 7b5f222 commit 16bbeb6
Showing 3 changed files with 27 additions and 46 deletions.
1 change: 1 addition & 0 deletions kedro-datasets/RELEASE.md
```diff
@@ -20,6 +20,7 @@
 | ------------------------------------ | -------------------------------------------------------------------------- | ----------------------------- |
 | `polars.CSVDataSet` | A `CSVDataSet` backed by [polars](https://www.pola.rs/), a lightning fast dataframe package built entirely using Rust. | `kedro_datasets.polars` |
 | `snowflake.SnowparkTableDataSet` | Work with [Snowpark](https://www.snowflake.com/en/data-cloud/snowpark/) DataFrames from tables in Snowflake. | `kedro_datasets.snowflake` |
+* Use `fsspec` in `SparkDataSet` to support more filesystems.
 
 ## Bug fixes and other changes
 * Add `mssql` backend to the `SQLQueryDataSet` DataSet using `pyodbc` library.
```
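For orientation, a minimal sketch of what the new `fsspec` fallback enables (not from this PR; the bucket, path and credentials are made up, and Spark itself still needs the matching Hadoop connector on its classpath to read `gs://` URLs):

```python
# Hypothetical usage sketch: with the fsspec fallback, SparkDataSet can
# resolve exists()/glob() on any filesystem fsspec knows about, not only
# s3a://, hdfs:// and /dbfs/ paths.
from kedro_datasets.spark import SparkDataSet

dataset = SparkDataSet(
    filepath="gs://my-bucket/reviews.parquet",  # made-up GCS location
    file_format="parquet",
    credentials={"token": "cloud"},  # forwarded to fsspec.filesystem(...)
)
df = dataset.load()  # returns a pyspark.sql.DataFrame
```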
64 changes: 26 additions & 38 deletions kedro-datasets/kedro_datasets/spark/spark_dataset.py
```diff
@@ -123,15 +123,6 @@ def _deployed_on_databricks() -> bool:
     return "DATABRICKS_RUNTIME_VERSION" in os.environ
 
 
-def _path_has_dbfs_prefix(path: str) -> bool:
-    """Check if a file path has a valid dbfs prefix.
-
-    Args:
-        path: File path to check.
-    """
-    return path.startswith("/dbfs/")
-
-
 class KedroHdfsInsecureClient(InsecureClient):
     """Subclasses ``hdfs.InsecureClient`` and implements ``hdfs_exists``
     and ``hdfs_glob`` methods required by ``SparkDataSet``"""
```
```diff
@@ -287,27 +278,29 @@ def __init__( # pylint: disable=too-many-arguments
         """
         credentials = deepcopy(credentials) or {}
         fs_prefix, filepath = _split_filepath(filepath)
+        path = PurePosixPath(filepath)
         exists_function = None
         glob_function = None
 
-        if fs_prefix in ("s3a://", "s3n://"):
-            if fs_prefix == "s3n://":
-                warn(
-                    "'s3n' filesystem has now been deprecated by Spark, "
-                    "please consider switching to 's3a'",
-                    DeprecationWarning,
-                )
+        if not filepath.startswith("/dbfs/") and _deployed_on_databricks():
+            logger.warning(
+                "Using SparkDataSet on Databricks without the `/dbfs/` prefix in the "
+                "filepath is a known source of error. You must add this prefix to %s",
+                filepath,
+            )
+        if fs_prefix and fs_prefix in ("s3a://"):
             _s3 = S3FileSystem(**credentials)
             exists_function = _s3.exists
             # Ensure cache is not used so latest version is retrieved correctly.
             glob_function = partial(_s3.glob, refresh=True)
-            path = PurePosixPath(filepath)
 
-        elif fs_prefix == "hdfs://" and version:
-            warn(
-                f"HDFS filesystem support for versioned {self.__class__.__name__} is "
-                f"in beta and uses 'hdfs.client.InsecureClient', please use with "
-                f"caution"
-            )
+        elif fs_prefix == "hdfs://":
+            if version:
+                warn(
+                    f"HDFS filesystem support for versioned {self.__class__.__name__} is "
+                    f"in beta and uses 'hdfs.client.InsecureClient', please use with "
+                    f"caution"
+                )
 
             # default namenode address
             credentials.setdefault("url", "http://localhost:9870")
```
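A note on the hunk above: the Databricks `/dbfs/` warning now fires up front for every filepath rather than only in the former `else` branch, and the `s3n` deprecation shim is dropped. A condensed sketch of the relocated guard (the filepath is made up, not code from the PR):

```python
# Condensed sketch of the relocated Databricks guard.
import os

filepath = "data/01_raw/reviews.parquet"  # made-up path without the /dbfs/ prefix
if not filepath.startswith("/dbfs/") and "DATABRICKS_RUNTIME_VERSION" in os.environ:
    print(f"You must add the /dbfs/ prefix to {filepath}")
```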
```diff
@@ -316,21 +309,18 @@ def __init__( # pylint: disable=too-many-arguments
             _hdfs_client = KedroHdfsInsecureClient(**credentials)
             exists_function = _hdfs_client.hdfs_exists
             glob_function = _hdfs_client.hdfs_glob  # type: ignore
-            path = PurePosixPath(filepath)
 
-        elif filepath.startswith("/dbfs/"):
-            dbutils = _get_dbutils(self._get_spark())
-            if dbutils:
-                glob_function = partial(_dbfs_glob, dbutils=dbutils)
-                exists_function = partial(_dbfs_exists, dbutils=dbutils)
-        else:
-            path = PurePosixPath(filepath)
-            if _deployed_on_databricks() and not _path_has_dbfs_prefix(filepath):
-                logger.warning(
-                    "Using SparkDataSet on Databricks without the `/dbfs/` prefix in the "
-                    "filepath is a known source of error. You must add this prefix to %s",
-                    filepath,
-                )
+        elif filepath.startswith("/dbfs"):
+            # dbfs add prefix to Spark path by default
+            # See https://github.com/kedro-org/kedro-plugins/issues/117
+            dbutils = _get_dbutils(self._get_spark())
+            if dbutils:
+                glob_function = partial(_dbfs_glob, dbutils=dbutils)
+                exists_function = partial(_dbfs_exists, dbutils=dbutils)
+        else:
+            fs = fsspec.filesystem(fs_prefix.strip("://"), **credentials)
+            exists_function = fs.exists
+            glob_function = fs.glob
 
         super().__init__(
             filepath=path,
```
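The `else` branch above is the core of the change: any scheme fsspec recognises now gets working `exists`/`glob` hooks. Note that `str.strip("://")` strips characters, not a substring, so `"file://"` collapses to the protocol name `"file"`. A minimal sketch of the fallback in isolation, assuming only that `fsspec` is installed:

```python
# Sketch of the fsspec fallback in isolation; not code from the PR.
import fsspec

fs_prefix = "file://"  # any scheme fsspec recognises
fs = fsspec.filesystem(fs_prefix.strip("://"))  # strip("://") -> "file"

# The same hooks SparkDataSet passes to its versioning machinery.
exists_function = fs.exists
glob_function = fs.glob
print(exists_function("/tmp"), glob_function("/tmp/*"))
```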
```diff
@@ -359,7 +349,6 @@ def __init__( # pylint: disable=too-many-arguments
 
     @staticmethod
     def _load_schema_from_file(schema: Dict[str, Any]) -> StructType:
-
         filepath = schema.get("filepath")
         if not filepath:
             raise DataSetError(
```

```diff
@@ -375,7 +364,6 @@ def _load_schema_from_file(schema: Dict[str, Any]) -> StructType:
 
         # Open schema file
         with file_system.open(load_path) as fs_file:
-
             try:
                 return StructType.fromJson(json.loads(fs_file.read()))
             except Exception as exc:
```
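The two `_load_schema_from_file` hunks above are whitespace-only. For context, the schema file that method opens holds JSON in the format `StructType.fromJson` consumes, i.e. the counterpart of `StructType.jsonValue()`. A small round-trip sketch, assuming only that `pyspark` is installed:

```python
# Round-trip sketch of the schema format _load_schema_from_file expects.
import json

from pyspark.sql.types import IntegerType, StringType, StructField, StructType

schema = StructType(
    [
        StructField("name", StringType(), True),
        StructField("age", IntegerType(), True),
    ]
)
as_json = json.dumps(schema.jsonValue())  # what the schema file would contain
assert StructType.fromJson(json.loads(as_json)) == schema
```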
8 changes: 0 additions & 8 deletions kedro-datasets/tests/spark/test_spark_dataset.py
```diff
@@ -800,14 +800,6 @@ def test_prevent_overwrite(self, mocker, versioned_dataset_s3):
 
         mocked_spark_df.write.save.assert_not_called()
 
-    def test_s3n_warning(self, version):
-        pattern = (
-            "'s3n' filesystem has now been deprecated by Spark, "
-            "please consider switching to 's3a'"
-        )
-        with pytest.warns(DeprecationWarning, match=pattern):
-            SparkDataSet(filepath=f"s3n://{BUCKET_NAME}/{FILENAME}", version=version)
-
     def test_repr(self, versioned_dataset_s3, version):
         assert "filepath=s3a://" in str(versioned_dataset_s3)
         assert f"version=Version(load=None, save='{version.save}')" in str(
```
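The deleted `test_s3n_warning` covered the `s3n` deprecation branch that no longer exists. Not part of this commit, but one way the new fallback could be exercised without mocks is against the local filesystem (test name and approach are hypothetical):

```python
# Hypothetical pytest sketch; "file" stands in for fs_prefix.strip("://").
import fsspec

def test_fsspec_fallback_resolves_local_files(tmp_path):
    (tmp_path / "part-0000.parquet").touch()
    fs = fsspec.filesystem("file")
    assert fs.exists(str(tmp_path / "part-0000.parquet"))
    assert fs.glob(str(tmp_path / "*.parquet"))
```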
