Skip to content

Commit

Permalink
get_beta_softlabels now returns a numpy array
Browse files Browse the repository at this point in the history
  • Loading branch information
victormvy committed Apr 16, 2024
1 parent 5a5243d commit e2b0c83
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
3 changes: 2 additions & 1 deletion dlordinal/soft_labelling/beta_distribution.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
from scipy.special import gamma, hyp2f1

from .utils import get_intervals
Expand Down Expand Up @@ -285,4 +286,4 @@ def get_beta_softlabels(J, params_set="standard"):
raise ValueError(f"Invalid params_set: {params_set}")

params = _beta_params_sets[params_set]
return [_get_beta_softlabel(J, p, q, a) for (p, q, a) in params[J]]
return np.array([_get_beta_softlabel(J, p, q, a) for (p, q, a) in params[J]])
5 changes: 3 additions & 2 deletions dlordinal/soft_labelling/tests/test_beta_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ def test_beta_softlabels():
n = 5
result = get_beta_softlabels(n)

assert len(result) == n
assert len(result[0]) == n
assert len(result.shape) == 2
assert result.shape[0] == n
assert result.shape[1] == n

expected_result = [
[
Expand Down

0 comments on commit e2b0c83

Please sign in to comment.