@@ -32,7 +32,9 @@ CREATE FUNCTION gds.vector.distance(list<double> list1, list<double> list2, stri
32
32
Exceptions:
33
33
list_size_mismatch (90000):
34
34
Raised when the input vectors are not of equal size.
35
- invalid_metric_type (90001):
35
+ zero_divisor(90001);
36
+ Raised either list is all zero to avoid zero-divisor issue.
37
+ invalid_metric_type (90002):
36
38
Raised when an unsupported distance metric is provided.
37
39
38
40
Logic Overview:
@@ -55,7 +57,8 @@ CREATE FUNCTION gds.vector.distance(list<double> list1, list<double> list2, stri
55
57
*/
56
58
57
59
EXCEPTION list_size_mismatch (90000);
58
- EXCEPTION invalid_metric_type (90001);
60
+ EXCEPTION zero_divisor(90001);
61
+ EXCEPTION invalid_metric_type (90002);
59
62
ListAccum<double> @@myList1 = list1;
60
63
ListAccum<double> @@myList2 = list2;
61
64
@@ -68,7 +71,18 @@ CREATE FUNCTION gds.vector.distance(list<double> list1, list<double> list2, stri
68
71
69
72
CASE lower(metric)
70
73
WHEN "cosine" THEN
71
- @@myResult = 1 - inner_product(@@myList1, @@myList2) / (sqrt(inner_product(@@myList1, @@myList1)) * sqrt(inner_product(@@myList2, @@myList2)));
74
+ double inner_p = inner_product(@@myList1, @@myList2);
75
+ double v1_magn = sqrt(inner_product(@@myList1, @@myList1));
76
+ double v2_magn = sqrt(inner_product(@@myList2, @@myList2));
77
+ IF (abs(v1_magn) < 0.0000001) THEN
78
+ // use a small positive float to avoid numeric comparison error
79
+ RAISE zero_divisor ("The elements in the first list are all zero. It will introduce a zero divisor.");
80
+ END;
81
+ IF (abs(v2_magn) < 0.0000001) THEN
82
+ // use a small positive float to avoid numeric comparison error
83
+ RAISE zero_divisor ("The elements in the second list are all zero. It will introduce a zero divisor.");
84
+ END;
85
+ @@myResult = 1 - inner_p / (v1_magn * v2_magn);
72
86
WHEN "l2" THEN
73
87
FOREACH i IN RANGE [0, @@myList1.size() - 1 ] DO
74
88
@@sqrSum += (@@myList1.get(i) - @@myList2.get(i)) * (@@myList1.get(i) - @@myList2.get(i));
0 commit comments