Skip to content

Commit

Permalink
Merge branch 'branch-24.10' into scikit-build-core-version
Browse files Browse the repository at this point in the history
  • Loading branch information
dantegd authored Aug 9, 2024
2 parents bdfbff7 + 19ffd6b commit 5ea986b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
4 changes: 2 additions & 2 deletions python/cuml/cuml/dask/preprocessing/LabelEncoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ class LabelEncoder(
0 a
1 a
2 b
0 c
1 b
3 c
4 b
dtype: object
>>> client.close()
>>> cluster.close()
Expand Down
20 changes: 11 additions & 9 deletions python/cuml/cuml/preprocessing/LabelEncoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,9 @@ def transform(self, y) -> cudf.Series:

y = cudf.Series(y, dtype="category")

encoded = y.cat.set_categories(self.classes_)._column.codes
encoded = cudf.Series(encoded, index=y.index)
encoded = y.cat.set_categories(self.classes_).cat.codes

if encoded.has_nulls and self.handle_unknown == "error":
if encoded.hasnans and self.handle_unknown == "error":
raise KeyError("Attempted to encode unseen key")

return encoded
Expand All @@ -237,9 +236,9 @@ def fit_transform(self, y, z=None) -> cudf.Series:
self.dtype = y.dtype if y.dtype != cp.dtype("O") else str

y = y.astype("category")
self.classes_ = y._column.categories
self.classes_ = y.cat.categories

return cudf.Series(y._column.codes, index=y.index)
return y.cat.codes

def inverse_transform(self, y: cudf.Series) -> cudf.Series:
"""
Expand Down Expand Up @@ -275,11 +274,14 @@ def inverse_transform(self, y: cudf.Series) -> cudf.Series:

y = y.astype(self.dtype)

ran_idx = cudf.Series(cp.arange(len(self.classes_))).astype(self.dtype)

reverted = y._column.find_and_replace(ran_idx, self.classes_, False)
# TODO: Remove ._column once .replace correctly accepts cudf.Index
ran_idx = (
cudf.Index(cp.arange(len(self.classes_)))
.astype(self.dtype)
._column
)
res = y.replace(ran_idx, self.classes_)

res = cudf.Series(reverted)
return res

def get_param_names(self):
Expand Down

0 comments on commit 5ea986b

Please sign in to comment.