Skip to content
This repository has been archived by the owner on Aug 9, 2023. It is now read-only.

Commit

Permalink
Add test for CNN where X td.data.Dataset and sparse_y=True
Browse files Browse the repository at this point in the history
  • Loading branch information
nsorros committed Feb 25, 2021
1 parent f9521c7 commit 3c7df33
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions tests/test_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,30 @@ def test_XY_dataset():

clf.fit(data)
assert clf.score(data, Y) > 0.6


def test_XY_dataset_sparse_y():
X = [
"One and two",
"One only",
"Two nothing else",
"Two and three"
]
Y = np.array([
[1, 1, 0, 0],
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 1, 1, 0]
])
Y_sparse = csr_matrix(Y)

vec = KerasVectorizer()
X_vec = vec.fit_transform(X)

data = tf.data.Dataset.from_tensor_slices((X_vec, Y))
data = data.shuffle(100)
clf = CNNClassifier(
batch_size=2, sparse_y=True, multilabel=True
)
clf.fit(data)
assert clf.score(data, Y_sparse) > 0.3

0 comments on commit 3c7df33

Please sign in to comment.