diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 7a2d4961..6ffaab65 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -247,3 +247,30 @@ def test_XY_dataset(): clf.fit(data) assert clf.score(data, Y) > 0.6 + + +def test_XY_dataset_sparse_y(): + X = [ + "One and two", + "One only", + "Two nothing else", + "Two and three" + ] + Y = np.array([ + [1, 1, 0, 0], + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 1, 1, 0] + ]) + Y_sparse = csr_matrix(Y) + + vec = KerasVectorizer() + X_vec = vec.fit_transform(X) + + data = tf.data.Dataset.from_tensor_slices((X_vec, Y)) + data = data.shuffle(100) + clf = CNNClassifier( + batch_size=2, sparse_y=True, multilabel=True + ) + clf.fit(data) + assert clf.score(data, Y_sparse) > 0.3