Skip to content

Commit

Permalink
Merge pull request #6 from ayrna/triangular-fix
Browse files Browse the repository at this point in the history
Triangular fix
  • Loading branch information
franberchez authored Nov 23, 2023
2 parents 7bc2afe + e4dc5c4 commit 06ead9b
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 104 deletions.
2 changes: 2 additions & 0 deletions dlordinal/datasets/tests/test_adience.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ def test_track_progress(adience_instance):
def test_process_and_split(adience_instance, monkeypatch):
global temp_dir

assert isinstance(temp_dir, tempfile.TemporaryDirectory)

df1 = pd.DataFrame.from_dict(
{
"user_id": ["30601258@N03", "30601258@N03"],
Expand Down
112 changes: 11 additions & 101 deletions dlordinal/distributions/triangular_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,89 +19,21 @@ def get_triangular_probabilities(n: int, alpha2: float = 0.01, verbose: int = 0)
Verbosity level, by default 0.
"""

print(f"Computing triangular probabilities for {n=} and {alpha2=}...")

def compute_alpha1(alpha2):
c_plus = (1 - 2 * alpha2) * (2 * alpha2 + math.sqrt(2 * alpha2))
c_minus = (1 - 2 * alpha2) * (2 * alpha2 - math.sqrt(2 * alpha2))

results = []

try:
results.append(pow((1 + math.sqrt(1 - 4 * c_plus)) / 2, 2))
except:
pass
try:
results.append(pow((1 - math.sqrt(1 - 4 * c_plus)) / 2, 2))
except:
pass
try:
results.append(pow((1 + math.sqrt(1 - 4 * c_minus)) / 2, 2))
except:
pass
try:
results.append(pow((1 - math.sqrt(1 - 4 * c_minus)) / 2, 2))
except:
pass

results.sort()
assert len(results) > 0 # and results[0] < alpha2
return results[0]
return pow((1 - math.sqrt(1 - 4 * c_minus)) / 2, 2)

def compute_alpha3(alpha2):
results = []
try:
c1 = (
pow((n - 1) / n, 2)
* (1 - 2 * alpha2)
* (math.sqrt(2 * alpha2) * (-1 + math.sqrt(2 * alpha2)))
)

# Solutions corresponding to the Equation 3
try:
results.append(pow((-1 + math.sqrt(1 + 4 * c1)) / 2, 2))
except:
pass
try:
results.append(pow((-1 - math.sqrt(1 + 4 * c1)) / 2, 2))
except:
pass
try:
results.append(pow((1 + math.sqrt(1 - 4 * c1)) / 2, 2))
except:
pass
try:
results.append(pow((1 - math.sqrt(1 - 4 * c1)) / 2, 2))
except:
pass

c2 = (
-pow((n - 1) / n, 2)
* (1 - 2 * alpha2)
* (math.sqrt(2 * alpha2) * (1 + math.sqrt(2 * alpha2)))
)
c1 = (
pow((n - 1) / n, 2)
* (1 - 2 * alpha2)
* (math.sqrt(2 * alpha2) * (-1 + math.sqrt(2 * alpha2)))
)

# Solutions corresponding to the Equation 4
try:
results.append(pow((-1 + math.sqrt(1 - 4 * c2)) / 2, 2))
except:
pass
try:
results.append(pow((-1 - math.sqrt(1 - 4 * c2)) / 2, 2))
except:
pass
try:
results.append(pow((-1 + math.sqrt(1 + 4 * c2)) / 2, 2))
except:
pass
try:
results.append(pow((-1 - math.sqrt(1 + 4 * c2)) / 2, 2))
except:
pass
except:
pass

results.sort()
assert len(results) > 0 and 0 < results[0] < 1
return results[0]
return pow((1 - math.sqrt(1 - 4 * c1)) / 2, 2)

alpha1 = compute_alpha1(alpha2)
alpha3 = compute_alpha3(alpha2)
Expand Down Expand Up @@ -143,26 +75,6 @@ def bj(n, j):
else (num1 - num2) / den
)

def nj(n):
num1 = 2 * alpha2
num2 = math.sqrt(2 * alpha2)
den = 2.0 * n * (1 - 2 * alpha2)

# +-
return (
(num1 + num2) / den if (num1 + num2) / den >= 0.0 else (num1 - num2) / den
)

def mj(n):
num1 = 2 * alpha2
num2 = math.sqrt(2 * alpha2)
den = 2.0 * n * (1 - 2 * alpha2)

# +-
return (
(num1 + num2) / den if (num1 + num2) / den >= 0.0 else (num1 - num2) / den
)

def aJ(n):
aJ_plus = 1.0 + 1.0 / (n * (math.sqrt(alpha3) - 1.0))
aJ_minus = 1.0 + 1.0 / (-n * (math.sqrt(alpha3) - 1.0))
Expand All @@ -174,11 +86,9 @@ def nJ(n):
return num / den

if verbose >= 3:
print(
f"{b1(n)=}, {m1(n)=}, {aJ(n)=}, {nJ(n)=}, {aj(n, 1)=}, {bj(n,1)=}, {nj(n)=}, {mj(n)=}"
)
print(f"{b1(n)=}, {m1(n)=}, {aJ(n)=}, {nJ(n)=}, {aj(n, 1)=}, {bj(n,1)=}")
for i in range(1, n + 1):
print(f"{i=} {aj(n, i)=}, {bj(n,i)=}, {nj(n)=}, {mj(n)=}")
print(f"{i=} {aj(n, i)=}, {bj(n,i)=}")

intervals = get_intervals(n)
probs = []
Expand Down
6 changes: 3 additions & 3 deletions dlordinal/distributions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ def triangular_cdf(x: float, a: float, b: float, c: float):
"""
if x <= a:
return 0
if a < x < c:
elif a < x < c:
return pow(x - a, 2) / ((b - a) * (c - a))
if c < x < b:
elif c < x < b:
return 1 - pow(b - x, 2) / ((b - a) * (b - c))
if b <= x:
else: # b <= x
return 1

0 comments on commit 06ead9b

Please sign in to comment.