Skip to content

Commit

Permalink
[GLE-8861] address comments;
Browse files Browse the repository at this point in the history
  • Loading branch information
jue-yuan committed Dec 5, 2024
1 parent 5160522 commit 3d49227
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 15 deletions.
15 changes: 13 additions & 2 deletions gds/vector/cosine_distance.gsql
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ CREATE FUNCTION gds.vector.cosine_distance(list<double> list1, list<double> list
Exceptions:
list_size_mismatch (90000):
Raised when the input lists are not of equal size.
zero_divisor(90001);
Raised either list is all zero to avoid zero-divisor issue.

Logic Overview:
Validates that both input vectors have the same length.
Expand All @@ -42,15 +44,24 @@ CREATE FUNCTION gds.vector.cosine_distance(list<double> list1, list<double> list
*/

EXCEPTION list_size_mismatch (90000);
EXCEPTION zero_divisor(90001);
ListAccum<double> @@myList1 = list1;
ListAccum<double> @@myList2 = list2;

IF (@@myList1.size() != @@myList2.size()) THEN
RAISE list_size_mismatch ("Two lists provided for gds.vector.cosine_distance have different sizes.");
END;

double innerP = inner_product(@@myList1, @@myList2);
double inner_p = inner_product(@@myList1, @@myList2);
double v1_magn = sqrt(inner_product(@@myList1, @@myList1));
double v2_magn = sqrt(inner_product(@@myList2, @@myList2));
RETURN (1 - innerP / (v1_magn * v2_magn));
IF (abs(v1_magn) < 0.0000001) THEN
// use a small positive float to avoid numeric comparison error
RAISE zero_divisor ("The elements in the first list are all zero. It will introduce a zero divisor.");
END;
IF (abs(v1_magn) < 0.0000001) THEN
// use a small positive float to avoid numeric comparison error
RAISE zero_divisor ("The elements in the second list are all zero. It will introduce a zero divisor.");
END;
RETURN (1 - inner_p / (v1_magn * v2_magn));
}
20 changes: 17 additions & 3 deletions gds/vector/distance.gsql
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ CREATE FUNCTION gds.vector.distance(list<double> list1, list<double> list2, stri
Exceptions:
list_size_mismatch (90000):
Raised when the input vectors are not of equal size.
invalid_metric_type (90001):
zero_divisor(90001);
Raised either list is all zero to avoid zero-divisor issue.
invalid_metric_type (90002):
Raised when an unsupported distance metric is provided.

Logic Overview:
Expand All @@ -55,7 +57,8 @@ CREATE FUNCTION gds.vector.distance(list<double> list1, list<double> list2, stri
*/

EXCEPTION list_size_mismatch (90000);
EXCEPTION invalid_metric_type (90001);
EXCEPTION zero_divisor(90001);
EXCEPTION invalid_metric_type (90002);
ListAccum<double> @@myList1 = list1;
ListAccum<double> @@myList2 = list2;

Expand All @@ -68,7 +71,18 @@ CREATE FUNCTION gds.vector.distance(list<double> list1, list<double> list2, stri

CASE lower(metric)
WHEN "cosine" THEN
@@myResult = 1 - inner_product(@@myList1, @@myList2) / (sqrt(inner_product(@@myList1, @@myList1)) * sqrt(inner_product(@@myList2, @@myList2)));
double inner_p = inner_product(@@myList1, @@myList2);
double v1_magn = sqrt(inner_product(@@myList1, @@myList1));
double v2_magn = sqrt(inner_product(@@myList2, @@myList2));
IF (abs(v1_magn) < 0.0000001) THEN
// use a small positive float to avoid numeric comparison error
RAISE zero_divisor ("The elements in the first list are all zero. It will introduce a zero divisor.");
END;
IF (abs(v2_magn) < 0.0000001) THEN
// use a small positive float to avoid numeric comparison error
RAISE zero_divisor ("The elements in the second list are all zero. It will introduce a zero divisor.");
END;
@@myResult = 1 - inner_p / (v1_magn * v2_magn);
WHEN "l2" THEN
FOREACH i IN RANGE [0, @@myList1.size() - 1 ] DO
@@sqrSum += (@@myList1.get(i) - @@myList2.get(i)) * (@@myList1.get(i) - @@myList2.get(i));
Expand Down
13 changes: 3 additions & 10 deletions gds/vector/norm.gsql
Original file line number Diff line number Diff line change
Expand Up @@ -53,23 +53,16 @@ CREATE FUNCTION gds.vector.norm(list<double> list1, string metric) RETURNS(float

EXCEPTION invalid_metric_type (90001);
ListAccum<double> @@myList1 = list1;
ListAccum<double> @@myList2;

FOREACH i IN RANGE [0, @@myList1.size() - 1] DO
@@myList2 += 0;
end;

SumAccum<float> @@myResult;
SumAccum<float> @@sqrSum;

CASE lower(metric)
WHEN "l2" THEN
FOREACH i IN RANGE [0, @@myList1.size() - 1 ] DO
@@sqrSum += (@@myList1.get(i) - @@myList2.get(i)) * (@@myList1.get(i) - @@myList2.get(i));
END;
@@myResult = sqrt(@@sqrSum);
@@myResult = sqrt(inner_product(@@myList1, @@myList1));
WHEN "ip" THEN
@@myResult = inner_product(@@myList1, @@myList2);
// the result of inner product between any vector and all-zero vector should always be 0
@@myResult = 0;
ELSE
RAISE invalid_metric_type ("Invalid metric algorithm provided, currently supported: l2 and ip.");
END
Expand Down

0 comments on commit 3d49227

Please sign in to comment.