Skip to content

Commit

Permalink
don't squeeze to hard! Ensuring atleast1d
Browse files Browse the repository at this point in the history
  • Loading branch information
c-w-feldmann committed Nov 20, 2024
1 parent 8c282a4 commit e279563
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions molpipeline/estimators/chemprop/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,7 @@ def _predict(
test_data = build_dataloader(X, num_workers=self.n_jobs, shuffle=False)
predictions = self.lightning_trainer.predict(self.model, test_data)
prediction_array = np.vstack(predictions) # type: ignore
prediction_array = prediction_array.squeeze()

prediction_array = np.atleast_1d(prediction_array.squeeze())
# Check if the predictions have the same length as the input dataset
if prediction_array.shape[0] != len(X):
raise AssertionError(
Expand Down

0 comments on commit e279563

Please sign in to comment.