Skip to content

Commit 9730d9b

Browse files
committed
[GLE-8861] address comments;
1 parent 5160522 commit 9730d9b

File tree

3 files changed

+32
-14
lines changed

3 files changed

+32
-14
lines changed

gds/vector/cosine_distance.gsql

+12-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ CREATE FUNCTION gds.vector.cosine_distance(list<double> list1, list<double> list
2828
Exceptions:
2929
list_size_mismatch (90000):
3030
Raised when the input lists are not of equal size.
31+
zero_divisor (90001):
32+
Raised when either input list is all zeros (zero magnitude), to avoid a division by zero.
3133

3234
Logic Overview:
3335
Validates that both input vectors have the same length.
@@ -42,15 +44,24 @@ CREATE FUNCTION gds.vector.cosine_distance(list<double> list1, list<double> list
4244
*/
4345

4446
EXCEPTION list_size_mismatch (90000);
47+
EXCEPTION zero_divisor(90001);
4548
ListAccum<double> @@myList1 = list1;
4649
ListAccum<double> @@myList2 = list2;
4750

4851
IF (@@myList1.size() != @@myList2.size()) THEN
4952
RAISE list_size_mismatch ("Two lists provided for gds.vector.cosine_distance have different sizes.");
5053
END;
5154

52-
double innerP = inner_product(@@myList1, @@myList2);
55+
double inner_p = inner_product(@@myList1, @@myList2);
5356
double v1_magn = sqrt(inner_product(@@myList1, @@myList1));
5457
double v2_magn = sqrt(inner_product(@@myList2, @@myList2));
58+
IF (abs(v1_magn) < 0.0000001) THEN
59+
// use a small positive float to avoid numeric comparison error
60+
RAISE zero_divisor ("The elements in the first list are all zero. It will introduce a zero divisor.");
61+
END;
62+
IF (abs(v2_magn) < 0.0000001) THEN // guard the SECOND vector's magnitude (v2_magn); v1_magn is already checked above
63+
// use a small positive float to avoid numeric comparison error
64+
RAISE zero_divisor ("The elements in the second list are all zero. It will introduce a zero divisor.");
65+
END;
5566
RETURN (1 - inner_p / (v1_magn * v2_magn)); // cosine distance = 1 - (dot product / product of magnitudes); uses the renamed inner_p variable
5667
}

gds/vector/distance.gsql

+17-3
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ CREATE FUNCTION gds.vector.distance(list<double> list1, list<double> list2, stri
3232
Exceptions:
3333
list_size_mismatch (90000):
3434
Raised when the input vectors are not of equal size.
35-
invalid_metric_type (90001):
35+
zero_divisor (90001):
36+
Raised when either input list is all zeros (zero magnitude) and the cosine metric is used, to avoid a division by zero.
37+
invalid_metric_type (90002):
3638
Raised when an unsupported distance metric is provided.
3739

3840
Logic Overview:
@@ -55,7 +57,8 @@ CREATE FUNCTION gds.vector.distance(list<double> list1, list<double> list2, stri
5557
*/
5658

5759
EXCEPTION list_size_mismatch (90000);
58-
EXCEPTION invalid_metric_type (90001);
60+
EXCEPTION zero_divisor(90001);
61+
EXCEPTION invalid_metric_type (90002);
5962
ListAccum<double> @@myList1 = list1;
6063
ListAccum<double> @@myList2 = list2;
6164

@@ -68,7 +71,18 @@ CREATE FUNCTION gds.vector.distance(list<double> list1, list<double> list2, stri
6871

6972
CASE lower(metric)
7073
WHEN "cosine" THEN
71-
@@myResult = 1 - inner_product(@@myList1, @@myList2) / (sqrt(inner_product(@@myList1, @@myList1)) * sqrt(inner_product(@@myList2, @@myList2)));
74+
double inner_p = inner_product(@@myList1, @@myList2);
75+
double v1_magn = sqrt(inner_product(@@myList1, @@myList1));
76+
double v2_magn = sqrt(inner_product(@@myList2, @@myList2));
77+
IF (abs(v1_magn) < 0.0000001) THEN
78+
// use a small positive float to avoid numeric comparison error
79+
RAISE zero_divisor ("The elements in the first list are all zero. It will introduce a zero divisor.");
80+
END;
81+
IF (abs(v2_magn) < 0.0000001) THEN
82+
// use a small positive float to avoid numeric comparison error
83+
RAISE zero_divisor ("The elements in the second list are all zero. It will introduce a zero divisor.");
84+
END;
85+
@@myResult = 1 - inner_p / (v1_magn * v2_magn);
7286
WHEN "l2" THEN
7387
FOREACH i IN RANGE [0, @@myList1.size() - 1 ] DO
7488
@@sqrSum += (@@myList1.get(i) - @@myList2.get(i)) * (@@myList1.get(i) - @@myList2.get(i));

gds/vector/norm.gsql

+3-10
Original file line numberDiff line numberDiff line change
@@ -53,23 +53,16 @@ CREATE FUNCTION gds.vector.norm(list<double> list1, string metric) RETURNS(float
5353

5454
EXCEPTION invalid_metric_type (90001);
5555
ListAccum<double> @@myList1 = list1;
56-
ListAccum<double> @@myList2;
57-
58-
FOREACH i IN RANGE [0, @@myList1.size() - 1] DO
59-
@@myList2 += 0;
60-
end;
6156

6257
SumAccum<float> @@myResult;
6358
SumAccum<float> @@sqrSum;
6459

6560
CASE lower(metric)
6661
WHEN "l2" THEN
67-
FOREACH i IN RANGE [0, @@myList1.size() - 1 ] DO
68-
@@sqrSum += (@@myList1.get(i) - @@myList2.get(i)) * (@@myList1.get(i) - @@myList2.get(i));
69-
END;
70-
@@myResult = sqrt(@@sqrSum);
62+
@@myResult = sqrt(inner_product(@@myList1, @@myList1));
7163
WHEN "ip" THEN
72-
@@myResult = inner_product(@@myList1, @@myList2);
64+
// the inner product of any vector with an all-zero vector is always 0
65+
@@myResult = 0;
7366
ELSE
7467
RAISE invalid_metric_type ("Invalid metric algorithm provided, currently supported: l2 and ip.");
7568
END

0 commit comments

Comments
 (0)