Skip to content

Commit d988e4e

Browse files
amitkdutta authored and facebook-github-bot committed
feat: Add cosine_similarity function for two arrays (facebookincubator#13311)
Summary: Similar to prestodb/presto#25056 Reviewed By: yuandagits Differential Revision: D74549208
1 parent e52b11a commit d988e4e

File tree

4 files changed

+100
-8
lines changed

4 files changed

+100
-8
lines changed

velox/docs/functions/presto/math.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,19 @@ Mathematical Functions
3838

3939
SELECT cosine_similarity(MAP(ARRAY[], ARRAY[]), MAP(ARRAY['a', 'b'], ARRAY[2, 3])); -- NaN
4040

41+
.. function:: cosine_similarity(array(double), array(double)) -> double
42+
43+
Returns the `cosine similarity <https://en.wikipedia.org/wiki/Cosine_similarity>`_ between the vectors represented as array(double).
44+
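If both input arrays are empty, or if either array has zero norm (for example, all of its elements are 0), the function returns NaN. If the input arrays have different sizes (including the case where only one of them is empty), the function throws VeloxUserError.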
45+
46+
SELECT cosine_similarity(ARRAY[1], ARRAY[2]); -- 1.0
47+
48+
SELECT cosine_similarity(ARRAY[1.0, 2.0], ARRAY[NULL, 3.0]); -- NULL
49+
50+
SELECT cosine_similarity(ARRAY[], ARRAY[2, 3]); -- Throws VeloxUserError
51+
52+
SELECT cosine_similarity(ARRAY[], ARRAY[]); -- NaN
53+
4154
.. function:: degrees(x) -> double
4255

4356
Converts angle x in radians to degrees.

velox/functions/prestosql/Arithmetic.h

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -656,18 +656,18 @@ struct WilsonIntervalLowerFunction {
656656
};
657657

658658
template <typename T>
659-
struct CosineSimilarityFunction {
659+
struct CosineSimilarityFunctionMap {
660660
VELOX_DEFINE_FUNCTION_TYPES(T);
661661

662-
double normalizeMap(const null_free_arg_type<Map<Varchar, double>>& map) {
662+
double normalize(const null_free_arg_type<Map<Varchar, double>>& map) {
663663
double norm = 0.0;
664664
for (const auto& [key, value] : map) {
665665
norm += (value * value);
666666
}
667667
return std::sqrt(norm);
668668
}
669669

670-
double mapDotProduct(
670+
double dotProduct(
671671
const null_free_arg_type<Map<Varchar, double>>& leftMap,
672672
const null_free_arg_type<Map<Varchar, double>>& rightMap) {
673673
double result = 0.0;
@@ -689,20 +689,67 @@ struct CosineSimilarityFunction {
689689
return;
690690
}
691691

692-
double normLeftMap = normalizeMap(leftMap);
692+
double normLeftMap = normalize(leftMap);
693693
if (normLeftMap == 0.0) {
694694
result = std::numeric_limits<double>::quiet_NaN();
695695
return;
696696
}
697697

698-
double normRightMap = normalizeMap(rightMap);
698+
double normRightMap = normalize(rightMap);
699699
if (normRightMap == 0.0) {
700700
result = std::numeric_limits<double>::quiet_NaN();
701701
return;
702702
}
703703

704-
double dotProduct = mapDotProduct(leftMap, rightMap);
705-
result = dotProduct / (normLeftMap * normRightMap);
704+
double product = dotProduct(leftMap, rightMap);
705+
result = product / (normLeftMap * normRightMap);
706+
}
707+
};
708+
709+
template <typename T>
710+
struct CosineSimilarityFunctionArray {
711+
VELOX_DEFINE_FUNCTION_TYPES(T);
712+
713+
double normalize(const null_free_arg_type<Array<double>>& map) {
714+
double norm = 0.0;
715+
for (const auto value : map) {
716+
norm += (value * value);
717+
}
718+
return std::sqrt(norm);
719+
}
720+
721+
double dotProduct(
722+
const null_free_arg_type<Array<double>>& leftArray,
723+
const null_free_arg_type<Array<double>>& rightArray) {
724+
double result = 0.0;
725+
for (size_t i = 0; i < leftArray.size(); i++) {
726+
result += leftArray[i] * rightArray[i];
727+
}
728+
return result;
729+
}
730+
731+
void callNullFree(
732+
out_type<double>& result,
733+
const null_free_arg_type<Array<double>>& leftArray,
734+
const null_free_arg_type<Array<double>>& rightArray) {
735+
VELOX_USER_CHECK(
736+
leftArray.size() == rightArray.size(),
737+
"Both arrays need to have identical size");
738+
739+
double normLeftArray = normalize(leftArray);
740+
if (normLeftArray == 0.0) {
741+
result = std::numeric_limits<double>::quiet_NaN();
742+
return;
743+
}
744+
745+
double normRightArray = normalize(rightArray);
746+
if (normRightArray == 0.0) {
747+
result = std::numeric_limits<double>::quiet_NaN();
748+
return;
749+
}
750+
751+
double product = dotProduct(leftArray, rightArray);
752+
result = product / (normLeftArray * normRightArray);
706753
}
707754
};
708755

velox/functions/prestosql/registration/MathematicalFunctionsRegistration.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,15 @@ void registerMathFunctions(const std::string& prefix) {
114114
registerTruncate({prefix + "truncate"});
115115

116116
registerFunction<
117-
CosineSimilarityFunction,
117+
CosineSimilarityFunctionMap,
118118
double,
119119
Map<Varchar, double>,
120120
Map<Varchar, double>>({prefix + "cosine_similarity"});
121+
registerFunction<
122+
CosineSimilarityFunctionArray,
123+
double,
124+
Array<double>,
125+
Array<double>>({prefix + "cosine_similarity"});
121126
}
122127

123128
} // namespace

velox/functions/prestosql/tests/ArithmeticTest.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,5 +1088,32 @@ TEST_F(ArithmeticTest, cosineSimilarity) {
10881088
.has_value());
10891089
}
10901090

1091+
// Exercises cosine_similarity(array(double), array(double)): regular vectors,
// empty inputs, mismatched sizes, zero vectors and non-finite elements.
TEST_F(ArithmeticTest, cosineSimilarityArray) {
  // Evaluates cosine_similarity on a single row holding the two arrays and
  // returns the (non-null) result.
  const auto cosineSimilarity = [&](const std::vector<double>& left,
                                    const std::vector<double>& right) {
    auto leftArray = makeArrayVector<double>({left});
    auto rightArray = makeArrayVector<double>({right});
    return evaluateOnce<double>(
               "cosine_similarity(c0,c1)",
               makeRowVector({leftArray, rightArray}))
        .value();
  };

  // dot = 1*1 + 2*3, norms = sqrt(1 + 4) and sqrt(1 + 9).
  EXPECT_DOUBLE_EQ(
      (1 * 1 + 2 * 3) / (std::sqrt(5.0) * std::sqrt(10.0)),
      cosineSimilarity({{1, 2}}, {{1, 3}}));

  EXPECT_DOUBLE_EQ(
      (1 * 1 + 2 * 3 + (-1) * 5) /
          (std::sqrt(1.0 + 4.0 + 1.0) * std::sqrt(1.0 + 9.0 + 25.0)),
      cosineSimilarity({{1, 2, -1}}, {{1, 3, 5}}));

  // Two empty arrays have zero norm, hence NaN.
  EXPECT_TRUE(std::isnan(cosineSimilarity({}, {})));
  // Arrays of different sizes are rejected, even if one of them is empty.
  VELOX_ASSERT_THROW(
      cosineSimilarity({1, 3}, {}), "Both arrays need to have identical size");
  // An all-zero vector has zero norm, hence NaN.
  EXPECT_TRUE(std::isnan(cosineSimilarity({1, 3}, {0, 0})));
  // Non-finite elements are expected to produce NaN as well.
  EXPECT_TRUE(std::isnan(cosineSimilarity({1, 3}, {kNan, 1})));
  EXPECT_TRUE(std::isnan(cosineSimilarity({1, 3}, {kInf, 1})));
}
1117+
10911118
} // namespace
10921119
} // namespace facebook::velox

0 commit comments

Comments
 (0)